domenica 17 luglio 2011

algoritmo k-NN: meglio in Scala

Stavo leggendo questo libro su machine learning. È un buon libro ma mi ero stufato della verbosità di Python, con cui sono presentati gli algoritmi, quindi ho scritto un'implementazione dell'algoritmo k-NN in Scala:
package knn

class Knn[TData, TClass,
          TDistance <% Ordered[TDistance] ]
(
  val k : Int,
  val dataset : List[TData],
  val getClassOfData : TData => TClass,
  val getDistance : (TData, TData) => TDistance
)
{
  if(k < 1)
    throw new scala.IllegalArgumentException("k must be positive")

  def classify(data : TData) : TClass = {

    // Compute distances between tested data and dataset items
    val distances = for( d <- dataset ) yield (d, getDistance(d, data))

    // sort data by distance
    val sorted = distances sortBy ( d => d._2 )

    // take first k items
    val firstk = sorted take k

    // group by labels
    val classes = firstk groupBy (d => getClassOfData(d._1))

    // get most frequent label
    val classification = classes maxBy (g => g._2.length)

    classification._1
  }
}
Wow, ho scritto questa classe in dieci minuti funziona con ogni tipo di dataset dove sia definita una distanza e ovviamente una classificazione, come si può vedere con questo test (scritto con ScalaTest):
package knn

import org.scalatest._
import org.scalatest.matchers._
import scala.math._

class KnnTest extends FlatSpec with ShouldMatchers {

  "2nn" should " classify correctly with a sample near two well-known labelled items" in {
    val alg = new Knn(2,
          List((3, "ok"), (18, "ok"), (21, "high"), (64, "high")),
          (i : (Int, String)) => i._2,
          (i1: (Int, String), i2: (Int, String)) => abs(i1._1 - i2._1))

    val classified = alg.classify((4, ""))

    classified should equal ("ok")
  }
}
Naturalmente il codice può essere ottimizzato e generalizzato ma mostra come si possa scrivere codice veramente disaccoppiato e generale in Scala.

Nessun commento: