import scala.collection.immutable.HashMap
/** Immutable (persistent) prefix tree mapping strings to values of type `T`, with one
  * character per edge. Every update returns a new node; child subtrees that are not on
  * the updated path are shared with the previous version.
  *
  * @param leaves       child subtrees keyed by the next character of a word
  * @param endPoint     whether the path from the root to this node spells a stored word
  * @param childrenSize number of words stored strictly below this node
  * @param data         value stored for the word ending here; expected to be `Some`
  *                     exactly when `endPoint` is true
  * @tparam T type of the stored values
  */
class Trie[T](leaves: Map[Char, Trie[T]], endPoint: Boolean, childrenSize: Int, data: Option[T]) {

  /** Creates a node with no children. */
  def this(endPoint: Boolean, data: Option[T]) = this(Map[Char, Trie[T]](), endPoint, 0, data)

  /** Returns a trie that additionally maps `word` to `value`; an existing value for
    * `word` is replaced and the word count is unchanged.
    */
  def add(word: String, value: T): Trie[T] =
    if (word.isEmpty) {
      // `Some`, not `Option(...)`: `Option(null)` is `None`, which would leave a word that
      // `size()` counts but `contains`/`get` cannot find.
      new Trie[T](leaves, true, childrenSize, Some(value))
    } else {
      val firstLetter = word.head
      // A missing branch is grown from an empty node, one character per level.
      val subtree = leaves.getOrElse(firstLetter, new Trie[T](false, None))
      withLeaves(leaves + (firstLetter -> subtree.add(word.tail, value)))
    }

  /** Returns a trie without `word`. If `word` does not follow an existing path, `this`
    * is returned unchanged.
    */
  def remove(word: String): Trie[T] =
    if (word.isEmpty) new Trie[T](leaves, false, childrenSize, None)
    else {
      val firstLetter = word.head
      leaves.get(firstLetter) match {
        case None => this
        case Some(subtree) =>
          val updated = subtree.remove(word.tail)
          // Prune branches that no longer lead to any stored word.
          if (updated.size() == 0) withLeaves(leaves - firstLetter)
          else withLeaves(leaves + (firstLetter -> updated))
      }
    }

  /** Number of words stored in this subtree, including the one ending at this node. */
  def size(): Int = if (endPoint) childrenSize + 1 else childrenSize

  /** Prints the tree to stdout: one edge character per line, one extra space of indent
    * per level and, at a node where a word ends, its stored value framed by dollar signs
    * and parentheses.
    */
  def print(indent: Int = 0): Unit = {
    val pad = " " * indent
    if (endPoint) data.foreach(value => println(pad + "$(" + value + ")$"))
    for ((letter, child) <- leaves) {
      println(pad + letter)
      child.print(indent + 1)
    }
  }

  /** Whether `word` is stored in this trie. */
  def contains(word: String): Boolean = get(word).isDefined

  /** Value stored for `word`, or `None` if `word` is not stored. */
  def get(word: String): Option[T] =
    if (word.isEmpty) { if (endPoint) data else None }
    else leaves.get(word.head).flatMap(_.get(word.tail))

  /** Rebuilds this node around `newLeaves`, recomputing the word count below it. */
  private def withLeaves(newLeaves: Map[Char, Trie[T]]): Trie[T] =
    new Trie[T](newLeaves, endPoint, countWords(newLeaves), data)

  /** Total number of words stored in the given child subtrees. */
  private def countWords(children: Map[Char, Trie[T]]): Int =
    children.valuesIterator.map(_.size()).sum
}
/** Factory for empty [[Trie]] instances. */
object Trie {

  /** An empty trie: a childless root that stores no word, not even the empty string. */
  def apply[T](): Trie[T] =
    new Trie[T](leaves = Map.empty[Char, Trie[T]], endPoint = false, childrenSize = 0, data = None)
}
/** Immutable path-compressed trie (radix tree) mapping strings to values of type `T`.
  *
  * Edges carry string labels. `add` splits an edge where a new word diverges from it, and
  * `remove` merges an edge with the one below it when the node between them no longer
  * stores a word and has a single child. Sibling edges start with distinct characters, so
  * at most one edge out of a node can share a prefix with a given word.
  *
  * @param leaves       child subtrees keyed by edge label
  * @param endPoint     whether the path from the root to this node spells a stored word
  * @param childrenSize number of words stored strictly below this node
  * @param data         value stored for the word ending here; expected to be `Some`
  *                     exactly when `endPoint` is true
  * @tparam T type of the stored values
  */
class MinimizedTrie[T](leaves: HashMap[String, MinimizedTrie[T]], endPoint: Boolean, childrenSize: Int, data: Option[T]) {

  /** Creates a node with no children. */
  def this(endPoint: Boolean, data: Option[T]) = this(HashMap[String, MinimizedTrie[T]](), endPoint, 0, data)

  /** Returns a trie that additionally maps `word` to `value`; an existing value for
    * `word` is replaced and the word count is unchanged.
    */
  def add(word: String, value: T): MinimizedTrie[T] =
    if (word.isEmpty) {
      // `Some`, not `Option(...)`: `Option(null)` is `None`, which would leave a word that
      // `size()` counts but `contains`/`get` cannot find.
      new MinimizedTrie[T](leaves, true, childrenSize, Some(value))
    } else {
      findEdge(word) match {
        case None =>
          // No edge starts with the word's first character: hang the whole word off this node.
          val leaf = new MinimizedTrie[T](true, Some(value))
          new MinimizedTrie[T](leaves + (word -> leaf), endPoint, childrenSize + 1, data)
        case Some((edge, child)) =>
          val common = commonPrefix(word, edge)
          if (common.length == edge.length) {
            withLeaves(leaves + (edge -> child.add(word.substring(edge.length), value)))
          } else {
            // The word diverges inside `edge`: split it after `common` with an intermediate
            // node that keeps the rest of the old edge and receives the rest of the word.
            val middle = new MinimizedTrie[T](HashMap(edge.substring(common.length) -> child), false, child.size(), None)
            withLeaves((leaves - edge) + (common -> middle.add(word.substring(common.length), value)))
          }
      }
    }

  /** Returns a trie without `word`. If `word` does not follow an existing path, `this`
    * is returned unchanged.
    */
  def remove(word: String): MinimizedTrie[T] =
    if (word.isEmpty) new MinimizedTrie[T](leaves, false, childrenSize, None)
    else {
      findEdge(word) match {
        case Some((edge, child)) if word.startsWith(edge) =>
          val updated = child.remove(word.substring(edge.length))
          if (updated.size() == 0) {
            // Nothing is stored below `edge` any more: drop it.
            withLeaves(leaves - edge)
          } else {
            updated.singleLeaf match {
              // `updated` no longer stores a word and has one child: merge the two edges so
              // the tree stays path-compressed.
              case Some((nextEdge, grandchild)) if !updated.isEndPoint =>
                withLeaves((leaves - edge) + ((edge + nextEdge) -> grandchild))
              case _ =>
                withLeaves(leaves + (edge -> updated))
            }
          }
        // `word` leaves the stored paths: no matching edge, or it ends/diverges inside one.
        case _ => this
      }
    }

  /** Whether `word` is stored in this trie. */
  def contains(word: String): Boolean = get(word).isDefined

  /** Value stored for `word`, or `None` if `word` is not stored. */
  def get(word: String): Option[T] =
    if (word.isEmpty) { if (endPoint) data else None }
    else findEdge(word).flatMap { case (edge, child) =>
      if (word.startsWith(edge)) child.get(word.substring(edge.length)) else None
    }

  /** Number of words stored in this subtree, including the one ending at this node. */
  def size(): Int = if (endPoint) childrenSize + 1 else childrenSize

  /** Prints the tree to stdout: one edge label per line, children indented by the length
    * of the edge leading to them and, at a node where a word ends, its stored value
    * framed by dollar signs and parentheses.
    */
  def print(indent: Int = 0): Unit = {
    val pad = " " * indent
    if (endPoint) data.foreach(value => println(pad + "$(" + value + ")$"))
    for ((edge, child) <- leaves) {
      println(pad + edge)
      child.print(indent + edge.length)
    }
  }

  /** Whether this node ends a stored word. The constructor parameter is not readable on
    * other instances, so sibling nodes use this accessor.
    */
  private def isEndPoint: Boolean = endPoint

  /** The only outgoing edge, if this node has exactly one. */
  private def singleLeaf: Option[(String, MinimizedTrie[T])] =
    if (leaves.size == 1) leaves.headOption else None

  /** The edge whose label starts with the first character of `word` (non-empty), if any.
    * Sibling edges start with distinct characters, so there is at most one.
    */
  private def findEdge(word: String): Option[(String, MinimizedTrie[T])] =
    leaves.find { case (edge, _) => edge.nonEmpty && edge.charAt(0) == word.charAt(0) }

  /** Longest common prefix of `a` and `b`. */
  private def commonPrefix(a: String, b: String): String = {
    val limit = math.min(a.length, b.length)
    val length = (0 until limit).find(i => a.charAt(i) != b.charAt(i)).getOrElse(limit)
    a.substring(0, length)
  }

  /** Rebuilds this node around `newLeaves`, recomputing the word count below it. */
  private def withLeaves(newLeaves: HashMap[String, MinimizedTrie[T]]): MinimizedTrie[T] =
    new MinimizedTrie[T](newLeaves, endPoint, countWords(newLeaves), data)

  /** Total number of words stored in the given child subtrees. */
  private def countWords(children: HashMap[String, MinimizedTrie[T]]): Int =
    children.valuesIterator.map(_.size()).sum
}
/** Factory for empty [[MinimizedTrie]] instances. */
object MinimizedTrie {

  /** An empty radix tree: a childless root that stores no word, not even the empty string. */
  def apply[T](): MinimizedTrie[T] =
    new MinimizedTrie[T](
      leaves = HashMap.empty[String, MinimizedTrie[T]],
      endPoint = false,
      childrenSize = 0,
      data = None
    )
}
// Smoke tests for MinimizedTrie. The same checks apply to the plain trie:
// replace the next line with `var t = Trie[Int]()` to run them against it.
var t = MinimizedTrie[Int]()
assert(!t.contains(""))
assert(!t.contains("a"))
assert(t.size() == 0)
t = t.add("abc", 0)
assert(t.size() == 1)
t = t.add("ab", 1)
assert(t.size() == 2)
t = t.add("hello", 2)
assert(t.size() == 3)
t = t.add("c", 3)
assert(t.size() == 4)
assert(!t.contains(""))
assert(!t.contains("a"))
assert(t.contains("ab"))
assert(t.contains("abc"))
assert(!t.contains("cabc"))
assert(!t.contains("hell"))
assert(t.contains("hello"))
// Stored values must come back attached to the right words.
assert(t.get("abc") == Some(0))
assert(t.get("ab") == Some(1))
assert(t.get("hello") == Some(2))
assert(t.get("c") == Some(3))
println()
assert(t.size() == 4)
t = t.remove("a")
assert(t.size() == 4)
t = t.remove("abc")
assert(t.size() == 3)
t = t.remove("aaa")
assert(t.size() == 3)
t = t.add("hell", 4)
t = t.remove("hell")
assert(t.size() == 3)
assert(!t.contains(""))
assert(!t.contains("a"))
assert(t.contains("ab"))
assert(!t.contains("abc"))
assert(!t.contains("cabc"))
assert(!t.contains("hell"))
assert(t.contains("hello"))
assert(t.get("ab") == Some(1))
println()
t = t.add("hell", 5)
t = t.add("hello", 10)
// Re-adding an existing word overwrites its value without changing the word count.
assert(t.size() == 4)
assert(t.get("hell") == Some(5))
assert(t.get("hello") == Some(10))
t.print()