Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ libraryDependencies += "com.lambdaworks" %% "jacks" % "2.3.3"

libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.2.1" % "test"

libraryDependencies += "com.databricks" % "spark-csv_2.10" % "1.2.0"


libraryDependencies ++= Seq(
"org.eclipse.jetty.orbit" % "javax.servlet" % "3.0.0.v201112011016" artifacts Artifact("javax.servlet", "jar", "jar"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package sampleclean.clean.outlierremoval

import org.apache.spark.sql.DataFrame
import sampleclean.api.WorkingSet

/**
* Common interface for outlier-removal algorithms: a single operation that
* takes a DataFrame and returns a DataFrame.
*
* @author Viraj Mahesh
*/
trait OutlierRemovalAlgorithm {

/**
* Removes outliers from a dataset. The algorithm for identifying outliers
* (i.e. what counts as an outlier) depends on the implementing class.
*
* @param dataFrame The dataset that we are cleaning
* @return A new DataFrame with outliers removed
*/
def removeOutliers(dataFrame: DataFrame): DataFrame
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package sampleclean.clean.outlierremoval

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, DataFrame}
import sampleclean.api.SampleCleanContext

/**
 * Detects outliers based on deviation from the mean.
 *
 * A row is kept iff the value x in the chosen column satisfies
 * abs(x - mean) <= maxDev * stdDev, where mean and stdDev are computed over
 * the non-null values of that column.
 *
 * @author Viraj Mahesh
 *
 * @param scc The sample clean context
 * @param maxDev An observation x is an outlier iff: abs(x - mean) > maxDev * stdDev
 * @param colName The name of the column that will be used to classify a row
 *                as an outlier
 */
class StdDeviationFilter(scc: SampleCleanContext,
                         maxDev: Double,
                         colName: String) extends OutlierRemovalAlgorithm with Serializable {

  /**
   * Convert x to a Double.
   *
   * @param x a numeric value (any java.lang.Number)
   * @return the value of x as a Double
   * @throws IllegalArgumentException if x is null or is not a Number
   */
  def toDouble(x: Any): Double = x match {
    case n: Number => n.doubleValue()
    case null =>
      throw new IllegalArgumentException(
        s"Cannot convert null in column '$colName' to Double")
    case other =>
      throw new IllegalArgumentException(
        s"Cannot convert value '$other' of type ${other.getClass.getName} " +
          s"in column '$colName' to Double")
  }

  /**
   * Removes every row whose value in `colName` lies more than
   * `maxDev` standard deviations away from the column mean.
   *
   * @param dataFrame The dataset that we are cleaning
   * @return A new DataFrame containing only the rows inside the accepted range
   * @throws IllegalArgumentException if `colName` is not a column of dataFrame
   */
  override def removeOutliers(dataFrame: DataFrame): DataFrame = {
    val colIdx = dataFrame.columns.indexOf(colName) // Get the column index of this column
    if (colIdx < 0) {
      // Fail on the driver with a readable message instead of an index error inside a task
      val available = dataFrame.columns.mkString(", ")
      throw new IllegalArgumentException(
        s"Column '$colName' not found; available columns: $available")
    }

    // Null cells carry no numeric information, so they are left out of the statistics
    val univariateData: RDD[Double] = dataFrame
      .map({ case x: Row => x(colIdx) })
      .filter(_ != null)
      .map(toDouble)

    // Calculate the mean and standard deviation of the column in a single pass
    val summary = univariateData.stats()
    val mean: Double = summary.mean
    val stdDev: Double = summary.stdev

    val lowerBound = mean - (maxDev * stdDev)
    val upperBound = mean + (maxDev * stdDev)

    // Only keep those rows that are at most maxDev standard deviations from the mean.
    // The bounds are inclusive so that a value exactly on the boundary (including every
    // value of a constant column, where stdDev == 0) is not classified as an outlier.
    val column = dataFrame.col(colName)
    dataFrame.filter(column >= lowerBound && column <= upperBound)
  }
}

object StdDeviationFilter {
  /**
   * Creates a new StdDeviationFilter and applies it on the dataset.
   *
   * @param scc The sample clean context
   * @param maxDev Maximum allowed deviation from the mean, in standard deviations
   * @param columnName The name of the column used to classify a row as an outlier
   * @param dataFrame The dataset that we are cleaning
   * @return The DataFrame returned by the filter's removeOutliers
   */
  def removeOutliers(scc: SampleCleanContext, maxDev: Double,
                     columnName: String, dataFrame: DataFrame): DataFrame =
    new StdDeviationFilter(scc, maxDev, columnName).removeOutliers(dataFrame)
}
5 changes: 5 additions & 0 deletions src/test/resources/students.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
'A',20,3.00
'B',22,3.23
'C',25,4.00
'D',30,4.25
'E',70,3.10
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package sampleclean.clean.outlierremoval

import org.apache.spark.sql.types._
import org.apache.spark.sql.SQLContext
import org.scalatest.FunSuite
import sampleclean.clean.LocalSCContext

/**
 * Runs StdDeviationFilter over the students CSV fixture, once using an
 * integer column ("age") and once using a double column ("gpa").
 *
 * @author Viraj Mahesh
 */
class StdDeviationFilterTest extends FunSuite with LocalSCContext with Serializable {

  // Schema of the students table
  val SCHEMA = StructType(List(
    StructField("name", StringType, true),
    StructField("age", IntegerType, true),
    StructField("gpa", DoubleType, true)))

  // Properties passed into the databricks CSV loader
  val PROPERTIES = Map("path" -> "./src/test/resources/students.csv")

  /**
   * Loads the students fixture through the databricks CSV loader and applies
   * the filter with maxDev = 1.0 on the given column.
   */
  private def filteredStudents(scc: sampleclean.api.SampleCleanContext,
                               column: String): org.apache.spark.sql.DataFrame = {
    val sqlContext = new SQLContext(scc.getSparkContext())
    val students = sqlContext.load("com.databricks.spark.csv", schema = SCHEMA, PROPERTIES)
    StdDeviationFilter.removeOutliers(scc, 1.0, column, students)
  }

  test("outlier removal integer column") {
    withSampleCleanContext { scc =>
      val survivors = filteredStudents(scc, "age")

      // Project the surviving rows down to their age value
      val ageIdx = survivors.columns.indexOf("age")
      val ages = survivors.map(row => row.getInt(ageIdx)).collect()

      assert(ages.length == 4)
      Seq(20, 22, 25, 30).foreach { expected =>
        assert(ages.contains(expected))
      }
    }
  }

  test("outlier removal double column") {
    withSampleCleanContext { scc =>
      val survivors = filteredStudents(scc, "gpa")

      // Project the surviving rows down to their gpa value
      val gpaIdx = survivors.columns.indexOf("gpa")
      val gpas = survivors.map(row => row.getDouble(gpaIdx)).collect()

      assert(gpas.length == 3)
      Seq(3.23, 4.00, 3.10).foreach { expected =>
        assert(gpas.contains(expected))
      }
    }
  }
}