/** Immutable segment tree over a binary operation `op`.
  *
  * Leaves hold the input elements; every inner node holds `op(left.value, right.value)`
  * and `size`, the number of leaves beneath it. Missing children are [[FakeNode]]s.
  * Index ranges used by `request` and `change` are 0-based and inclusive.
  * NOTE(review): `request` splits ranges arbitrarily, so `op` is presumably meant to be
  * associative (true for the `min` and string-concatenation demos below).
  */
abstract class SegmentTree[T]() {
  def op: (T, T) => T
  def size: Int
  def value: T
  def left: SegmentTree[T]
  def right: SegmentTree[T]

  /** Midpoint of the inclusive range [l, r]; `l + (r - l) / 2` avoids the overflow of `(l + r) / 2`. */
  private def mid(l: Int, r: Int): Int = l + (r - l) / 2

  /** Builds a balanced tree whose leaves are the elements of `arr`, combined with this node's `op`.
    *
    * Only `op` is taken from the receiver; its contents are ignored.
    * An empty array yields a [[FakeNode]] (size 0) instead of throwing.
    */
  def buildTree(arr: Array[T]): SegmentTree[T] = {
    def recBuild(l: Int, r: Int): SegmentTree[T] =
      if (l < r) {
        val m = mid(l, r)
        val leftNode = recBuild(l, m)
        val rightNode = recBuild(m + 1, r)
        RealNode[T](op(leftNode.value, rightNode.value), leftNode, rightNode, op, r - l + 1)
      } else RealNode[T](arr(l), FakeNode[T](op), FakeNode[T](op), op, 1)

    // Without this guard an empty array would read arr(0) and throw ArrayIndexOutOfBoundsException.
    if (arr.isEmpty) FakeNode[T](op) else recBuild(0, arr.length - 1)
  }

  /** Debug helper: prints each real node's value, left subtree before right, indenting
    * children by two more spaces per level. Fake nodes print nothing.
    */
  def dumpSubTree(node: SegmentTree[T] = this, tabCount: Int = 0): Unit = node match {
    case RealNode(value, left, right, _, _) =>
      println(" " * tabCount + "→" + value)
      dumpSubTree(left, tabCount + 2)
      dumpSubTree(right, tabCount + 2)
    case _ =>
  }

  /** Returns `op` folded over the elements `left..right` (0-based, inclusive).
    *
    * Also prints a debug trace to stdout: "▶ " once, then one
    * "Node(l, r) - Request(left, right)" line per visited node, suffixed with "■"
    * when the node's range exactly matches the sub-request.
    *
    * @throws IllegalArgumentException if the range is reversed or not within [0, size)
    */
  def request(left: Int, right: Int): T = {
    // Without this check a reversed or out-of-range query reaches a FakeNode (whose children
    // are itself) and recurses until StackOverflowError.
    require(0 <= left && left <= right && right < size,
      s"invalid range [$left, $right] for tree of size $size")

    def descend(node: SegmentTree[T], l: Int, r: Int, left: Int, right: Int): T = {
      val covered = l == left && r == right
      // Debug trace of the visited node.
      println("Node(" + l + ", " + r + ") - Request(" + left + ", " + right + ")" + (if (covered) "■" else ""))
      val m = mid(l, r)
      if (covered) node.value
      else if (right <= m) descend(node.left, l, m, left, right)
      else if (left > m) descend(node.right, m + 1, r, left, right)
      else op(descend(node.left, l, m, left, m), descend(node.right, m + 1, r, m + 1, right))
    }

    print("▶ ")
    descend(this, 0, size - 1, left, right)
  }

  /** Returns a new tree with element `index` replaced by `value`; the receiver is not modified.
    *
    * Only nodes on the root-to-leaf path are rebuilt; the untouched sibling subtrees are shared.
    *
    * @throws IllegalArgumentException if `index` is not within [0, size)
    */
  def change(index: Int, value: T): SegmentTree[T] = {
    // Out-of-range indices would otherwise silently overwrite the first or last leaf.
    require(0 <= index && index < size, s"index $index out of range for tree of size $size")

    def descend(node: SegmentTree[T], l: Int, r: Int): SegmentTree[T] =
      if (l == r) RealNode(value, FakeNode(op), FakeNode(op), op, 1)
      else {
        val m = mid(l, r)
        if (index <= m) {
          val newLeft = descend(node.left, l, m)
          RealNode(op(newLeft.value, node.right.value), newLeft, node.right, op, newLeft.size + node.right.size)
        } else {
          val newRight = descend(node.right, m + 1, r)
          RealNode(op(node.left.value, newRight.value), node.left, newRight, op, node.left.size + newRight.size)
        }
      }

    descend(this, 0, size - 1)
  }
}
//==============================================================================
/** Leaf or inner node of a [[SegmentTree]]; `size` is the number of leaves in this subtree.
  * NOTE(review): the default `size = 0` is never relied on in this file; all constructions pass it.
  */
case class RealNode[T](value: T, left: SegmentTree[T], right: SegmentTree[T], op: (T, T) => T, size: Int = 0)
  extends SegmentTree[T] {
}
//==============================================================================
/** Empty placeholder: the children of every leaf, and the result of building from an empty array.
  * It is its own left and right child, has size 0, and throws if its value is read.
  */
case class FakeNode[T](op: (T, T) => T) extends SegmentTree[T] {
  def value = sys.error("Ooops... you are in the fake node :(")
  val left = this
  val right = this
  val size = 0
}
//==============================================================================
/** Returns the smaller of `a` and `b` by `compareTo`; on a tie returns `b`. */
def min[T <: Comparable[T]](a: T, b: T): T = if (a.compareTo(b) < 0) a else b

/** String concatenation, used as the "sum" operation of the RSQ demo. */
def sum(a: String, b: String): String = a + b
//==============================================================================
var RMQ: SegmentTree[Integer] = FakeNode[Integer](min)
var RSQ: SegmentTree[String] = FakeNode[String](sum)
RMQ = RMQ.buildTree(Array[Integer](0, 1, 2, 3, 4))
RSQ = RSQ.buildTree(Array[String]("a", "b", "c", "d", "e"))
println("RMQ tree:")
RMQ.dumpSubTree()
println("RSQ tree:")
RSQ.dumpSubTree()
println("RMQ requests:")
println("min in [0, 3] = " + RMQ.request(0, 3))
println("min in [0, 2] = " + RMQ.request(0, 2))
println("min in [2, 2] = " + RMQ.request(2, 2))
println("RSQ requests:")
println("sum in [1, 3] = " + RSQ.request(1, 3))
println("sum in [2, 3] = " + RSQ.request(2, 3))
println("Change test:")
println("Change (index = 2, value = 'New value')")
RSQ = RSQ.change(2, "New value")
RSQ.dumpSubTree()