Porter Stemming in Apache Spark ML

13 December 2015

While experimenting with Apache Spark ML, I needed a stemming algorithm, so I decided to write a custom transformer myself.

As of Spark 1.5.2, stemming has not been introduced (it should arrive in 1.7.0), so I have taken the Porter stemming algorithm implemented in Scala by the ScalaNLP project and wrapped it as a Spark Transformer. Unfortunately, you will have to build Spark from source to use it.

To do that, simply place this code in:

mllib/src/main/scala/org/apache/spark/ml/feature/PorterStemmer.scala

and build from source:

export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"
mvn -T2C -Pyarn -Phadoop-2.6 -Dhadoop.version=2.6.0 -Phive -Phive-thriftserver -DskipTests package

Once the build completes, copy the spark-assembly jar over Spark’s default jar:

cp assembly/target/scala-2.10/spark-assembly-1.6.0-hadoop2.6.0.jar $SPARK_HOME/lib

Here is the code:

On GitHub: PorterStemmer.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
import scala.Some



private[spark] object PorterStem {
  /*
   Copyright 2009 David Hall, Daniel Ramage
   Licensed under the Apache License, Version 2.0 (the "License")
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
  */

  /**
   * Converts a word to its stemmed form using the classic Porter stemming
   * algorithm.
   *
   * All rules are written against lower-case letters, so the word is lower-cased
   * (with `Locale.ROOT`) first. Words shorter than three characters are returned
   * lower-cased but otherwise untouched. Internally a consonant 'y' is marked as 'Y';
   * that marker is turned back into 'y' before returning, so no upper-case marker
   * ever leaks into a stem.
   *
   * @param w the word to stem (must not be null)
   * @return the Porter stem of `w`
   * @author dlwh
   */
  def apply(w: String): String = {
    val lower = w.toLowerCase(java.util.Locale.ROOT)
    if (lower.length < 3) lower
    else {
      // 'y' behaves as a consonant at the start of a word or directly after a vowel.
      // Mark those occurrences as 'Y' (which isVowel rejects) so the measure and cvc
      // checks treat them as consonants.
      val marked = lower.replaceAll("([aeiou])y", "$1Y").replaceAll("^y", "Y")
      // Undo the marker: the input was lower-cased, so every 'Y' here is one we added.
      step5(step4(step3(step2(step1(marked))))).replace('Y', 'y')
    }
  }

  /** Step 1: plural, past-tense and "-ing" endings, then a final 'y' (steps 1a, 1b, 1c). */
  private def step1(w: String): String = step1c(step1b(step1a(w)))

  /** Step 1a: get rid of plural s's: sses -> ss, ies -> i, ss -> ss, s -> (removed). */
  private def step1a(w: String): String = {
    if (w.endsWith("sses") || w.endsWith("ies"))
      w.substring(0, w.length - 2)
    else if (w.length >= 2 && w.endsWith("s") && w.charAt(w.length - 2) != 's')
      w.substring(0, w.length - 1)
    else w
  }

  /**
   * Step 1b: "-eed" becomes "-ee" when the preceding stem has m > 0. "-ed" and "-ing"
   * are removed when the preceding stem contains a vowel, and the stem is then tidied.
   * Tidying means: "-at"/"-bl"/"-iz" gain an 'e'; a double consonant other than
   * l, s or z loses one letter; a short (m == 1 and cvc) stem gains an 'e'.
   */
  private def step1b(w: String): String = {
    def tidy(stem: String): String = {
      if (stem.endsWith("at") || stem.endsWith("bl") || stem.endsWith("iz")) stem + 'e'
      else if (doublec(stem) && !("lsz".contains(stem.last))) stem.substring(0, stem.length - 1)
      else if (m(stem) == 1 && cvc(stem)) stem + "e"
      else stem
    }

    if (w.endsWith("eed")) {
      if (m(w.substring(0, w.length - 3)) > 0) w.substring(0, w.length - 1)
      else w
    } else if (w.endsWith("ed")) {
      // The first vowel must lie before the suffix, i.e. inside the stem.
      if (w.indexWhere(isVowel) < (w.length - 2)) tidy(w.substring(0, w.length - 2))
      else w
    } else if (w.endsWith("ing")) {
      if (w.indexWhere(isVowel) < (w.length - 3)) tidy(w.substring(0, w.length - 3))
      else w
    } else w
  }

  /**
   * Step 1c: a final 'y' (or marked 'Y') becomes 'i' when the stem before it contains
   * a vowel. Empty input is returned unchanged.
   */
  def step1c(w: String): String = {
    val endsInY = w.nonEmpty && (w.last == 'y' || w.last == 'Y')
    val firstVowel = w.indexWhere(isVowel)
    // firstVowel == -1 means there is no vowel at all, so the rule must not fire.
    if (endsInY && firstVowel >= 0 && firstVowel < w.length - 1) {
      w.substring(0, w.length - 1) + 'i'
    } else w
  }

  /** Splits `w` into (stem, replacement) if it ends with `suffix`, else None. */
  private def replaceSuffix(w: String, suffix: String, repl: String) = {
    if (w endsWith suffix) Some((w.substring(0, w.length - suffix.length), repl))
    else None
  }

  /** Condition m(stem) > 0 on a (stem, replacement) candidate. */
  private val mgt0 = {
    (w: (String, String)) => m(w._1) > 0
  }

  /** Condition m(stem) > 1 on a (stem, replacement) candidate. */
  private val mgt1 = {
    (w: (String, String)) => m(w._1) > 1
  }

  /**
   * Step 2: maps double suffixes to single ones (e.g. "-ational" -> "-ate") when the
   * remaining stem has m > 0. Rules are bucketed by the penultimate letter. Within a
   * bucket the longer suffix is tried first, and only the first matching suffix is
   * considered: if its condition fails, the word is left unchanged.
   */
  private def step2(w: String): String = {
    if (w.length < 3) w
    else {
      val opt = w(w.length - 2) match {
        case 'a' => replaceSuffix(w, "ational", "ate").orElse(replaceSuffix(w, "tional", "tion"))
        case 'c' =>
          replaceSuffix(w, "enci", "ence").orElse(replaceSuffix(w, "anci", "ance"))
        case 'e' => replaceSuffix(w, "izer", "ize")
        case 'g' => replaceSuffix(w, "logi", "log")
        case 'l' => replaceSuffix(w, "bli", "ble") orElse {
          replaceSuffix(w, "alli", "al")
        } orElse {
          replaceSuffix(w, "entli", "ent")
        } orElse {
          replaceSuffix(w, "eli", "e")
        } orElse {
          replaceSuffix(w, "ousli", "ous")
        }
        case 'o' => replaceSuffix(w, "ization", "ize") orElse {
          replaceSuffix(w, "ator", "ate")
        } orElse {
          replaceSuffix(w, "ation", "ate")
        }
        case 's' => replaceSuffix(w, "alism", "al") orElse {
          replaceSuffix(w, "iveness", "ive")
        } orElse {
          replaceSuffix(w, "fulness", "ful")
        } orElse {
          replaceSuffix(w, "ousness", "ous")
        }
        case 't' =>
          replaceSuffix(w, "aliti", "al") orElse {
            replaceSuffix(w, "iviti", "ive")
          } orElse {
            replaceSuffix(w, "biliti", "ble")
          }
        case _ => None
      }
      opt.filter(mgt0).map {
        case (a, b) => a + b
      }.getOrElse(w)
    }
  }

  /**
   * Step 3: "-icate", "-ative", "-alize", "-iciti", "-ical", "-ful" and "-ness" endings,
   * applied when the remaining stem has m > 0. Rules are bucketed by the last letter.
   */
  private def step3(w: String): String = {
    if (w.length < 3) w
    else {
      val opt = w.last match {
        case 'e' =>
          replaceSuffix(w, "icate", "ic") orElse {
            replaceSuffix(w, "alize", "al")
          } orElse {
            replaceSuffix(w, "ative", "")
          }
        case 'i' => replaceSuffix(w, "iciti", "ic")
        case 'l' => replaceSuffix(w, "ical", "ic").orElse(replaceSuffix(w, "ful", ""))
        case 's' => replaceSuffix(w, "ness", "")
        case _ => None
      }
      opt.filter(mgt0).map {
        case (a, b) => a + b
      }.getOrElse(w)
    }
  }

  /**
   * Step 4: removes residual suffixes ("-al", "-ance", "-ement", "-ion", ...) when the
   * remaining stem has m > 1. Rules are bucketed by the penultimate letter.
   */
  private def step4(w: String): String = {
    if (w.length < 3) w
    else {
      val opt = w(w.length - 2) match {
        case 'a' => replaceSuffix(w, "al", "")
        case 'c' => replaceSuffix(w, "ance", "").orElse(replaceSuffix(w, "ence", ""))
        case 'e' => replaceSuffix(w, "er", "")
        case 'i' => replaceSuffix(w, "ic", "")
        case 'l' => replaceSuffix(w, "able", "").orElse(replaceSuffix(w, "ible", ""))
        case 'n' => replaceSuffix(w, "ant", "") orElse {
          replaceSuffix(w, "ement", "")
        } orElse {
          replaceSuffix(w, "ment", "")
        } orElse {
          replaceSuffix(w, "ent", "")
        }
        case 'o' =>
          // "-ion" is only removed when the stem ends in 's' or 't'.
          replaceSuffix(w, "ion", "")
            .filter(a => a._1.endsWith("t") || a._1.endsWith("s"))
            .orElse(replaceSuffix(w, "ou", ""))
        case 's' => replaceSuffix(w, "ism", "")
        case 't' => replaceSuffix(w, "ate", "").orElse(replaceSuffix(w, "iti", ""))
        case 'u' => replaceSuffix(w, "ous", "")
        case 'v' => replaceSuffix(w, "ive", "")
        case 'z' => replaceSuffix(w, "ize", "")
        case _ => None
      }
      opt.filter(mgt1).map {
        case (a, b) => a + b
      }.getOrElse(w)
    }
  }

  /** Step 5: final 'e' removal (5a) followed by "-ll" reduction (5b). */
  private def step5(w: String): String = {
    if (w.length < 3) w
    else step5b(step5a(w))
  }

  /**
   * Step 5a: drop a final 'e' when m > 1, or when m == 1 and the stem is not cvc.
   * A trailing vowel never adds a VC pair, so m(w) equals m of the stem without the 'e'.
   */
  private def step5a(w: String): String = {
    if (w.length < 3 || w.last != 'e') w
    else {
      val n = m(w)
      val stem = w.substring(0, w.length - 1)
      if (n > 1 || (n == 1 && !cvc(stem))) stem
      else w
    }
  }

  /** Step 5b: "-ll" becomes "-l" when m > 1. */
  private def step5b(w: String): String = {
    if (w.last == 'l' && doublec(w) && m(w) > 1) w.substring(0, w.length - 1)
    else w
  }

  /**
   * Porter's measure m of `w` = [C](VC)^m[V], capped at 2.
   * Callers only ever distinguish m == 0, m == 1 and m > 1.
   */
  def m(w: String): Int = {
    val firstV = w.indexWhere(isVowel)
    if (firstV == -1) 0
    else {
      // Skip the optional leading consonant run; then count vowel-run/consonant-run pairs.
      var rest = w.substring(firstV)
      var count = 0
      while (count < 2 && rest.nonEmpty) {
        rest = rest.dropWhile(isVowel)
        if (rest.nonEmpty) {
          count += 1
          rest = rest.dropWhile(isConsonant)
        }
      }
      count
    }
  }

  /**
   * True if `w` ends consonant-vowel-consonant and the final consonant is not
   * w, x or a marked consonant 'Y' (Porter's *o condition).
   */
  private def cvc(w: String): Boolean = (
    w.length > 2
      && isConsonant(w.last)
      && !("wxY" contains w.last)
      && isVowel(w(w.length - 2))
      && isConsonant(w.charAt(w.length - 3))
    )

  /** True if `w` ends in a doubled consonant (Porter's *d condition). */
  private def doublec(w: String): Boolean = {
    w.length > 2 && w.last == w.charAt(w.length - 2) && isConsonant(w.last)
  }

  /** Anything that is not a vowel (including the 'Y' consonant marker) is a consonant. */
  def isConsonant(letter: Char): Boolean = !isVowel(letter)

  /** Lower-case 'y' counts as a vowel; consonant uses are pre-marked as 'Y' by `apply`. */
  def isVowel(letter: Char): Boolean = "aeiouy" contains letter
}


/**
 * :: Experimental ::
 * Replaces every term in an array-of-strings column with its Porter stem.
 *
 * The input column must have type `ArrayType(StringType)`. The output column has the
 * same type and holds the stems produced by `PorterStem`, in the same order as the
 * input terms. A null input array, or a null term inside an array, is passed through
 * as null instead of failing the job.
 */
@Experimental
class PorterStemmer(override val uid: String)
  extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("porterStemmer"))

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Appends the output column, holding the stemmed terms of the input column, to `dataset`.
   */
  override def transform(dataset: DataFrame): DataFrame = {
    val outputSchema = transformSchema(dataset.schema)
    // Null arrays/terms are propagated as null rather than triggering a NullPointerException.
    val stem = udf { terms: Seq[String] =>
      Option(terms).map(_.map(term => Option(term).map(PorterStem(_)).orNull)).orNull
    }
    val metadata = outputSchema($(outputCol)).metadata
    dataset.select(col("*"), stem(col($(inputCol))).as($(outputCol), metadata))
  }

  /**
   * Validates that the input column is `ArrayType(StringType)` and that the output column
   * does not already exist, then appends the output column with the input's type and
   * nullability.
   */
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    require(inputType.sameType(ArrayType(StringType)),
      s"Input type must be ArrayType(StringType) but got $inputType.")
    require(!schema.fieldNames.contains($(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    val outputFields = schema.fields :+
      StructField($(outputCol), inputType, schema($(inputCol)).nullable)
    StructType(outputFields)
  }

  override def copy(extra: ParamMap): PorterStemmer = defaultCopy(extra)
}

/**
 * Companion object providing `PorterStemmer.load(path)`. It reads back a stemmer
 * persisted through the `DefaultParamsWritable` support mixed into the class.
 */
@Since("1.6.0")
object PorterStemmer extends DefaultParamsReadable[PorterStemmer] {

  /** Reads a previously saved [[PorterStemmer]] from `path`. */
  @Since("1.6.0")
  override def load(path: String): PorterStemmer = {
    val restored: PorterStemmer = super.load(path)
    restored
  }
}
If you find an error, please raise a pull request.