Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.nio.file.{FileAlreadyExistsException, Files, Paths}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.collection.JavaConverters._

import org.rocksdb._

Expand Down Expand Up @@ -68,35 +69,109 @@ private[master] class RocksDBPersistenceEngine(
*
* https://github.com/facebook/rocksdb/wiki/Compression#configuration
*/
private val options = new Options()
private def createCfOptions(): ColumnFamilyOptions = {
new ColumnFamilyOptions()
.setBottommostCompressionType(CompressionType.ZSTD_COMPRESSION)
.setCompressionType(CompressionType.LZ4_COMPRESSION)
.setTableFormatConfig(tableFormatConfig)
}

private val KNOWN_PREFIXES = Seq("app_", "driver_", "worker_")

private val dbOptions = new DBOptions()
.setCreateIfMissing(true)
.setBottommostCompressionType(CompressionType.ZSTD_COMPRESSION)
.setCompressionType(CompressionType.LZ4_COMPRESSION)
.setTableFormatConfig(tableFormatConfig)
.setCreateMissingColumnFamilies(true)

// Discover existing column families first to ensure we can open an existing DB
private val existingCfs = try {
RocksDB.listColumnFamilies(new Options(), path.toString).asScala.map(new String(_, UTF_8))
} catch {
case _: RocksDBException => Seq(new String(RocksDB.DEFAULT_COLUMN_FAMILY, UTF_8))
}

private val allCfs = (Seq(new String(RocksDB.DEFAULT_COLUMN_FAMILY, UTF_8)) ++ KNOWN_PREFIXES ++ existingCfs).distinct

private val cfDescriptors = new java.util.ArrayList[ColumnFamilyDescriptor]()
allCfs.foreach { cfName =>
cfDescriptors.add(new ColumnFamilyDescriptor(cfName.getBytes(UTF_8), createCfOptions()))
}

private val cfHandles = new java.util.ArrayList[ColumnFamilyHandle]()
private val db: RocksDB = RocksDB.open(dbOptions, path.toString, cfDescriptors, cfHandles)

private val cfHandleMap: Map[String, ColumnFamilyHandle] = {
allCfs.zipWithIndex.flatMap { case (name, idx) =>
if (KNOWN_PREFIXES.contains(name)) Some(name -> cfHandles.get(idx))
else None
}.toMap
}

private val db: RocksDB = RocksDB.open(options, path.toString)
private val defaultCFHandle = cfHandles.get(allCfs.indexOf(new String(RocksDB.DEFAULT_COLUMN_FAMILY, UTF_8)))

private def getCFHandle(name: String): ColumnFamilyHandle = {
cfHandleMap.find { case (prefix, _) => name.startsWith(prefix) }
.map(_._2)
.getOrElse(defaultCFHandle)
}

migrateOldData()

private def migrateOldData(): Unit = {
val iter = db.newIterator(defaultCFHandle)
val writeBatch = new WriteBatch()
var count = 0
try {
iter.seekToFirst()
while (iter.isValid) {
val key = iter.key()
val keyStr = new String(key, UTF_8)
val handle = cfHandleMap.find { case (prefix, _) => keyStr.startsWith(prefix) }.map(_._2)
handle.foreach { h =>
writeBatch.put(h, key, iter.value())
writeBatch.delete(defaultCFHandle, key)
count += 1
}
iter.next()
}
} finally {
iter.close()
}
if (count > 0) {
logInfo(s"Migrated $count records from default column family to specific column families.")
val writeOptions = new WriteOptions().setSync(true)
try {
db.write(writeOptions, writeBatch)
} finally {
writeOptions.close()
}
}
writeBatch.close()
}

override def persist(name: String, obj: Object): Unit = {
val serialized = serializer.newInstance().serialize(obj)
val cfHandle = getCFHandle(name)
if (serialized.hasArray) {
db.put(name.getBytes(UTF_8), serialized.array())
db.put(cfHandle, name.getBytes(UTF_8), serialized.array())
} else {
val bytes = new Array[Byte](serialized.remaining())
serialized.get(bytes)
db.put(name.getBytes(UTF_8), bytes)
db.put(cfHandle, name.getBytes(UTF_8), bytes)
}
}

override def unpersist(name: String): Unit = {
db.delete(name.getBytes(UTF_8))
db.delete(getCFHandle(name), name.getBytes(UTF_8))
}

override def read[T: ClassTag](name: String): Seq[T] = {
override def read[T: ClassTag](prefix: String): Seq[T] = {
val result = new ArrayBuffer[T]
val iter = db.newIterator()
val cfHandle = getCFHandle(prefix)
val iter = db.newIterator(cfHandle)
try {
iter.seek(name.getBytes(UTF_8))
while (iter.isValid && new String(iter.key()).startsWith(name)) {
val prefixBytes = prefix.getBytes(UTF_8)
iter.seek(prefixBytes)
while (iter.isValid && startsWith(iter.key(), prefixBytes)) {
result.append(serializer.newInstance().deserialize[T](ByteBuffer.wrap(iter.value())))
iter.next()
}
Expand All @@ -105,4 +180,21 @@ private[master] class RocksDBPersistenceEngine(
}
result.toSeq
}

private def startsWith(key: Array[Byte], prefix: Array[Byte]): Boolean = {
if (key.length < prefix.length) return false
var i = 0
while (i < prefix.length) {
if (key(i) != prefix(i)) return false
i += 1
}
true
}

override def close(): Unit = {
cfHandles.asScala.foreach(_.close())
if (db != null) {
db.close()
}
}
}