Tuesday 12 January 2010

Scala: implementing a "did you mean..?" spelling corrector

I was looking at Scala again and decided to implement Peter Norvig's algorithm for suggesting spelling corrections. I suggest you go read How to Write a Spelling Corrector for the clever stuff.

This implementation is limited to the English alphabet. You'll need the big.txt file or a similar set of training data.

Spell check in Java

Since I don't really know Python and am only getting to grips with Scala, I decided to implement the code in Java first.

I started by defining an interface to write to:

//Spell.java
public interface Spell {
  /**
   * Takes a normalised (e.g. all lower case) word term and returns the most
   * likely correction. If the argument is known, it is returned. If no
   * corrections can be found, the argument is returned.
   *
   * @param term
   *          the word to check
   * @return the most likely correction
   */
  public String correct(String term);
}

The implementation of this interface was written for clarity and efficiency rather than brevity.

//JSpell.java
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class JSpell implements Spell {
  private final Map<String, Integer> dictionary;

  public JSpell(Map<String, Integer> dictionary) {
    this.dictionary = dictionary;
  }

  private void edits1(String term, Receiver output) {
    StringBuilder buffer = new StringBuilder(term.length() + 1);
    // deletion (fixes an extra-character typo)
    for (int i = 0; i < term.length(); i++) {
      String pre = term.substring(0, i);
      String post = term.substring(i + 1);
      String word = clear(buffer).append(pre).append(post).toString();
      output.add(word);
    }
    // transposition (character-swap typo)
    for (int i = 0; i < term.length() - 1; i++) {
      String pre = term.substring(0, i);
      char swap1 = term.charAt(i + 1);
      char swap2 = term.charAt(i);
      String post = term.substring(i + 2);
      String word = clear(buffer).append(pre).append(swap1).append(swap2)
          .append(post).toString();
      output.add(word);
    }
    // replacement (mistyped character)
    for (int i = 0; i < term.length(); i++) {
      String pre = term.substring(0, i);
      String post = term.substring(i + 1);
      for (char c = 'a'; c <= 'z'; c++) {
        String word = clear(buffer).append(pre).append(c).append(post)
            .toString();
        output.add(word);
      }
    }
    // insertion (fixes a missing-character typo); n + 1 insertion points
    for (int i = 0; i < term.length() + 1; i++) {
      String pre = term.substring(0, i);
      String post = term.substring(i);
      for (char c = 'a'; c <= 'z'; c++) {
        String word = clear(buffer).append(pre).append(c).append(post)
            .toString();
        output.add(word);
      }
    }
  }

  private StringBuilder clear(StringBuilder buffer) {
    return buffer.delete(0, buffer.length());
  }

  @Override
  public String correct(final String term) {
    if (dictionary.containsKey(term)) {
      return term;
    }
    final Result result = new Result(term);
    final Set<String> editSet = new HashSet<String>();
    edits1(term, new Receiver() {
      @Override
      public void add(String word) {
        result.add(word);
        if (!result.hasCandidate()) {
          editSet.add(word);
        }
      }
    });
    if (result.hasCandidate()) {
      return result.word;
    }

    for (String edit : editSet) {
      edits1(edit, result);
    }
    return result.word;
  }

  private class Result implements Receiver {
    private String word;
    private int weight;

    public Result(String word) {
      this.word = word;
    }

    @Override
    public void add(String term) {
      Integer weight = dictionary.get(term);
      if (weight != null && weight > this.weight) {
        this.weight = weight;
        word = term;
      }
    }

    public boolean hasCandidate() {
      return weight > 0;
    }
  }

  private static interface Receiver {
    public void add(String word);
  }
}

This code works by pushing correction candidates into a Receiver; the Result receiver keeps the best candidate it has been passed so far (that is, the one with the highest frequency in the dictionary). The correction sequence goes like this:

  • If the term is in the dictionary, return the term.
  • Else if a candidate one edit (a deletion, transposition, replacement or insertion) away from the term is in the dictionary, return the candidate with the highest weight (frequency).
  • Else if a candidate two edits away from the term is in the dictionary, return the candidate with the highest weight.
  • Else return the term.
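It helps to see how many strings each step has to look up. For a non-empty term of length n, one round of edits produces n deletions, n - 1 transpositions, 26n replacements and 26(n + 1) insertions, i.e. 54n + 25 candidates (duplicates included). The hypothetical EditCount class below is not part of the original code: it repeats the four loops of edits1 but collects into a List instead of a Receiver, so the count can be checked. For "spellin" (n = 7) that is 403 first-round candidates; the two-edit step then expands every one of those again, which is why JSpell only falls back to it when the first round finds nothing.

```java
// EditCount.java (hypothetical helper, not part of the original code)
import java.util.ArrayList;
import java.util.List;

class EditCount {
  // Every one-edit candidate of term, duplicates included.
  static List<String> edits1(String term) {
    List<String> out = new ArrayList<String>();
    int n = term.length();
    // deletions: drop the character at i
    for (int i = 0; i < n; i++) {
      out.add(term.substring(0, i) + term.substring(i + 1));
    }
    // transpositions: swap the characters at i and i + 1
    for (int i = 0; i < n - 1; i++) {
      out.add(term.substring(0, i) + term.charAt(i + 1) + term.charAt(i)
          + term.substring(i + 2));
    }
    // replacements: overwrite the character at i with each letter
    for (int i = 0; i < n; i++) {
      for (char c = 'a'; c <= 'z'; c++) {
        out.add(term.substring(0, i) + c + term.substring(i + 1));
      }
    }
    // insertions: add each letter at each of the n + 1 positions
    for (int i = 0; i < n + 1; i++) {
      for (char c = 'a'; c <= 'z'; c++) {
        out.add(term.substring(0, i) + c + term.substring(i));
      }
    }
    return out;
  }

  public static void main(String[] args) {
    List<String> candidates = edits1("spellin");
    System.out.println(candidates.size()); // 403, i.e. 54 * 7 + 25
    System.out.println(candidates.contains("spelling")); // true
  }
}
```

Since "spellin" is missing its final letter, the correction comes from the insertion loop, which is also why that loop has to cover position n.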

Spell check in Scala

Rather than do a straight port of the Java code, I tried to implement this version in idiomatic Scala. Still, I imagine I've made some less than stellar implementation choices in the eyes of seasoned Scala developers.

//SSpell.scala
class SSpell(data: Map[String,Int]) {  
  private val dictionary = new FrequencyMap(data)
  private val alphabet = 'a' to 'z'
  
  /**Map that will return a default value of zero for no entry*/
  private class FrequencyMap[K](m: Map[K, Int]) extends
                                     scala.collection.mutable.HashMap[K, Int] {
    override def default(key: K) = 0
    this ++= m // add the entries in place (from Scala 2.8, ++ returns a new map)
  }
  
  private def edit1(word: String): Seq[String] = {
    // "hello" becomes ("", "hello"), ("h", "ello"), etc. 
    val splits = (0 to word.length).map(i => (word.take(i), word.drop(i)))
    val deleted = splits.filter(_._2.length > 0)
      .map(tuple => tuple._1 + tuple._2.drop(1))
    val transposed = splits.filter(_._2.length > 1)
      .map(tuple => tuple._1 + tuple._2(1) + tuple._2(0) + tuple._2.drop(2))
    val replaced = splits.filter(_._2.length > 0)
      .flatMap(tuple => alphabet.map(tuple._1 + _ + tuple._2.drop(1)))
    val inserted = splits
      .flatMap(tuple => alphabet.map(tuple._1 + _ + tuple._2))
    deleted ++ transposed ++ replaced ++ inserted 
  }
  
  private def edit2(word: String, edits: Seq[String]) = {
    val edit2 = for(edit <- edits;
                  e2 <- edit1(edit);
                  known = (e2, dictionary(e2));
                  if(known._2>0)) yield known
    edit2.foldLeft((word, 0)){best(_,_)}
  }
  
  private def best(a: (String,Int), b: (String,Int)) = if(b._2 > a._2) b else a
  
  private def fix(word: String, edits: Seq[String]) = {
    val terms = edits.map(term => (term, dictionary(term)))
    val correction = terms.foldLeft((word, 0)){best(_,_)}
    if(correction._2 > 0) correction._1 else edit2(word, edits)._1
  }
  
  def correct(word: String) =
    if(dictionary(word) > 0) word else fix(word, edit1(word))
}

If you aren't familiar with Scala, some of this can look quite daunting. Let's break down this method:

  private def fix(word: String, edits: Seq[String]) = {
    val terms = edits.map(term => (term, dictionary(term)))
    val correction = terms.foldLeft((word, 0)){best(_,_)}
    if(correction._2 > 0) correction._1 else edit2(word, edits)._1
  }

The underscore _ used by itself as an argument is Scala's placeholder syntax; in this context it roughly means "the next argument", so best(_,_) is shorthand for the function (a, b) => best(a, b). (foo,bar) creates a tuple, whose elements are read back with _1 and _2. map and foldLeft iterate over the elements of the collection and apply a function to each one. map returns another sequence; foldLeft starts from an initial value (here (word, 0)), combines it with each element in turn and returns a single value. If you are a Java developer, you may grok this rough translation more easily:

  private String fix(String word, Seq<String> edits) {
    class GetBest implements Comparer<Tuple<String, Integer>> {
      @Override
      public Tuple<String, Integer> choose(Tuple<String, Integer> t1,
          Tuple<String, Integer> t2) {
        return best(t1, t2);
      }
    }

    class Lookup implements Mapper<String, Tuple<String, Integer>> {
      @Override
      public Tuple<String, Integer> mapTo(String t) {
        return new Tuple<String, Integer>(t, dictionary.get(t));
      }
    }

    Seq<Tuple<String, Integer>> terms = edits.map(new Lookup());
    Tuple<String, Integer> correction = terms.foldLeft(
        new Tuple<String, Integer>(word, 0), new GetBest());

    return correction._2 > 0 ? correction._1 : edit2(word, edits)._1;
  }

Note: this isn't a literal translation of what Scala does - I just want to give a better sense of the logic involved.
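For readers who want something that actually compiles, here is a sketch of the first step of fix using Java 8 streams instead of the made-up Seq, Tuple, Mapper and Comparer types. It is an illustration under assumptions, not what scalac generates: Map.Entry stands in for the tuple, getOrDefault plays the role of FrequencyMap's zero default, and reduce with an identity value acts like foldLeft because the stream is sequential. The FixSketch class and its methods are hypothetical names, and the fallback to edit2 is left out.

```java
// FixSketch.java (hypothetical, Java 8+)
import java.util.AbstractMap.SimpleEntry;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class FixSketch {
  static Map.Entry<String, Integer> entry(String term, int weight) {
    return new SimpleEntry<String, Integer>(term, weight);
  }

  // Same rule as SSpell.best: b wins only if strictly more frequent,
  // so on a tie the earlier candidate is kept.
  static Map.Entry<String, Integer> best(Map.Entry<String, Integer> a,
      Map.Entry<String, Integer> b) {
    return b.getValue() > a.getValue() ? b : a;
  }

  // Most frequent known edit, or (word, 0) when no edit is known;
  // SSpell.fix would then fall back to edit2 (omitted here).
  static Map.Entry<String, Integer> bestKnown(String word, List<String> edits,
      Map<String, Integer> dictionary) {
    return edits.stream()
        // edits.map(term => (term, dictionary(term))); unknown words weigh 0
        .map(term -> entry(term, dictionary.getOrDefault(term, 0)))
        // terms.foldLeft((word, 0)){best(_,_)}; sequential, so left to right
        .reduce(entry(word, 0), FixSketch::best);
  }

  public static void main(String[] args) {
    Map<String, Integer> dictionary = new HashMap<>();
    dictionary.put("spelling", 12); // made-up counts for illustration
    dictionary.put("spieling", 1);
    List<String> edits = Arrays.asList("peling", "spelling", "spieling");
    Map.Entry<String, Integer> c = bestKnown("speling", edits, dictionary);
    System.out.println(c.getKey() + " " + c.getValue()); // spelling 12
  }
}
```

A weight of 0 in the result is the signal that no one-edit candidate was known, which is exactly the test correction._2 > 0 makes in the Scala code.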

Did you mean ..?

The application utilising the spell checker was written in Java and I wanted to just swap in functionality written in Scala. One of the attractions of Scala is the way you can integrate it with existing code and I wanted to try this out.

//DidYouMean.java
import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Scanner;
import java.util.regex.Pattern;

public class DidYouMean {

  public static Map<String, Integer> load(File file) throws IOException {
    Map<String, Integer> dictionary = new HashMap<String, Integer>();
    Scanner scanner = new Scanner(file, "ASCII").useDelimiter(Pattern
        .compile("[^a-zA-Z]++"));
    try {
      while (scanner.hasNext()) {
        String word = scanner.next().toLowerCase(Locale.ENGLISH);
        Integer value = dictionary.get(word);
        dictionary.put(word, value == null ? 1 : value + 1);
      }
    } finally {
      scanner.close();
    }
    if (scanner.ioException() != null) {
      throw scanner.ioException();
    }
    return Collections.unmodifiableMap(dictionary);
  }

  public static void main(String[] args) throws IOException {
    Map<String, Integer> dictionary = load(new File("big.txt"));
    // Spell checker = new JSpell(dictionary); // Java impl
    Spell checker = new SSpellAdapter(dictionary); // Scala impl
    for (String word : args) {
      String normalised = word.toLowerCase(Locale.ENGLISH);
      String correction = checker.correct(normalised);
      System.out.format("%s: did you mean %s?%n", normalised, correction);
    }
  }
}
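The load method treats every run of non-letters as a delimiter, lower-cases each token and counts occurrences, so the map ends up holding raw word frequencies; those counts are the weights both checkers compare. The sketch below does the same counting on an in-memory string, assuming Java 8's Map.merge and a regex split in place of Scanner. The WordCount class and count method are hypothetical names.

```java
// WordCount.java (hypothetical, Java 8+)
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

class WordCount {
  // Counts lower-cased runs of ASCII letters, as DidYouMean.load does.
  static Map<String, Integer> count(String text) {
    Map<String, Integer> counts = new HashMap<>();
    for (String token : text.split("[^a-zA-Z]+")) {
      if (token.isEmpty()) {
        continue; // split yields a leading "" when text starts with a non-letter
      }
      // insert 1 for a new word, otherwise add 1 to the existing count
      counts.merge(token.toLowerCase(Locale.ENGLISH), 1, Integer::sum);
    }
    return counts;
  }

  public static void main(String[] args) {
    Map<String, Integer> counts = count("The cat saw the other cat. THE END");
    System.out.println(counts.get("the")); // 3
    System.out.println(counts.get("cat")); // 2
  }
}
```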

There are two obstacles to switching implementations. One: SSpell doesn't implement the Spell interface. Two: Scala collection classes aren't part of the Java collections type hierarchy, so the java.util.Map built by load can't be passed straight to SSpell. A simple adapter class fixes both problems.

//SSpellAdapter.scala
class SSpellAdapter(javaMap: java.util.Map[String, Int]) extends Spell {
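  // scala.collection.jcl is the Scala 2.7 Java-interop package; Scala 2.8
  // replaced it with scala.collection.JavaConversions
  import scala.collection.jcl.Conversions
  val scalaMap = Map.empty ++ Conversions.convertMap(javaMap)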
  val corrector = new SSpell(scalaMap)
  
  def correct(word: String): String = corrector.correct(word)
}

Here is the application as run from the command line:

X:\test>java -cp C:\Scala\lib\scala-library.jar;. DidYouMean spellin
spellin: did you mean spelling?

Benchmarking

Introducing SSpellAdapter makes it really easy to check the Scala implementation against the Java version. As a bonus, it creates a fairly level playing field for benchmarking.

These are the results of running both checkers against the tests1 data set on a 1.8GHz Centrino:

class SSpellAdapter
Time: 12.605887 seconds
Total: 265; Right: 197; Wrong: 68; Unknown: 15; Pct: 74.339623

class JSpell
Time: 3.845939 seconds
Total: 265; Right: 197; Wrong: 68; Unknown: 15; Pct: 74.339623

Don't read too much into the timings on this micro-benchmark. After all, it would be possible to write a procedural Scala implementation just like the Java one. All this is really measuring is my proficiency with functional Scala.
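The benchmark harness itself isn't shown above. As a hypothetical sketch only, loosely modelled on the spelltest function in Norvig's article: each (misspelling, target) pair counts as Right when the corrector returns the target and Wrong otherwise; a wrong answer whose target is absent from the dictionary is also counted as Unknown (an assumption, though it is consistent with Unknown being smaller than Wrong above); and Pct is 100 × Right / Total, e.g. 100 × 197 / 265 ≈ 74.339623. The SpellBench class, the UnaryOperator corrector and the tiny data set are illustrative; any Spell could be plugged in with a method reference such as checker::correct.

```java
// SpellBench.java (hypothetical harness, Java 8+)
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.UnaryOperator;

class SpellBench {
  static final class Score {
    int total, right, wrong, unknown;
    double pct() { return total == 0 ? 0.0 : 100.0 * right / total; }
  }

  // Each test is {misspelling, target}.
  static Score run(List<String[]> tests, UnaryOperator<String> corrector,
      Map<String, Integer> dictionary) {
    Score s = new Score();
    for (String[] pair : tests) {
      s.total++;
      if (corrector.apply(pair[0]).equals(pair[1])) {
        s.right++;
      } else {
        s.wrong++;
        // a miss whose target never appears in the training data
        if (!dictionary.containsKey(pair[1])) s.unknown++;
      }
    }
    return s;
  }

  public static void main(String[] args) {
    Map<String, Integer> dictionary = new HashMap<>();
    dictionary.put("spelling", 1);
    List<String[]> tests = Arrays.asList(
        new String[] {"spellin", "spelling"},
        new String[] {"speling", "spelling"},
        new String[] {"quizz", "quiz"});
    // stand-in corrector; a real run would pass checker::correct
    UnaryOperator<String> corrector = w -> w.startsWith("spel") ? "spelling" : w;
    long start = System.nanoTime();
    Score s = run(tests, corrector, dictionary);
    double secs = (System.nanoTime() - start) / 1e9;
    System.out.format(Locale.ENGLISH, "Time: %f seconds%n", secs);
    System.out.format(Locale.ENGLISH,
        "Total: %d; Right: %d; Wrong: %d; Unknown: %d; Pct: %f%n",
        s.total, s.right, s.wrong, s.unknown, s.pct());
  }
}
```

Because both checkers are called through the same Spell interface, a harness like this measures them on identical inputs; only the corrector passed in changes.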

Remarks

Java version 6; Scala version 2.7.7.

Windows build script:

@REM COMPILE.BAT
@ECHO OFF
javac Spell.java
CALL scalac.bat -classpath . *.scala
javac -cp %SCALA_HOME%\lib\scala-library.jar;. *.java
