Sunday 15 September 2013

Java: "Did you mean 'stream'?" (Java 8 pre-release)

This post re-implements the "Did you mean...?" spell checker from an earlier post using the new Stream API in Java 8. Word frequencies are counted from a text corpus loaded from the classpath resource /demo/big.txt.

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static java.util.stream.Stream.concat;

/**
 * A small "did you mean ...?" spelling corrector built on Java 8 streams.
 *
 * <p>Candidate corrections are produced by applying one or two single-character edits (delete,
 * adjacent transpose, replace, insert) to the lower-cased input. Among the candidates that appear
 * in a word-frequency dictionary, the most frequent one is suggested.
 */
public class DidYouMean {
  /** Classpath resource whose word counts form the default dictionary. */
  private static final String CORPUS_RESOURCE = "/demo/big.txt";

  /** Runs of non-letter characters; the text between them is treated as a word. */
  private static final Pattern WORD_SEPARATOR = Pattern.compile("[^a-zA-Z]++");

  /**
   * Holds the default dictionary. It is nested so that the corpus is read on the first lookup
   * through {@link #didYouMean(String)}, not whenever {@code DidYouMean} itself is loaded
   * (initialization-on-demand holder).
   */
  private static final class DefaultDictionary {
    static final Map<String, Integer> WORDS = loadDictionary();
  }

  /**
   * Returns every string exactly one edit away from {@code term}. The edits are: each
   * single-character deletion, each adjacent transposition, each replacement of a character by
   * 'a'..'z', and each insertion of 'a'..'z'. The stream may contain duplicates. It also
   * contains {@code term} itself whenever {@code term} has at least one letter in 'a'..'z',
   * because replacing that letter with itself is one of the edits.
   */
  private static Stream<String> edit1(String term) {
    Stream<String> delete = IntStream.range(0, term.length())
        .mapToObj(n -> term.substring(0, n) + term.substring(n + 1));
    // For an empty term range(0, -1) is empty, so no transpositions are produced.
    Stream<String> transpose = IntStream.range(0, term.length() - 1)
        .mapToObj(n -> term.substring(0, n) + term.charAt(n + 1) + term.charAt(n) + term.substring(n + 2));
    Stream<String> replace = IntStream.range(0, term.length())
        .boxed()
        .flatMap(n -> IntStream.rangeClosed('a', 'z').mapToObj(c -> term.substring(0, n) + (char) c + term.substring(n + 1)));
    // Insertion positions run from 0 to length inclusive, so letters can be appended.
    Stream<String> insert = IntStream.rangeClosed(0, term.length())
        .boxed()
        .flatMap(n -> IntStream.rangeClosed('a', 'z').mapToObj(c -> term.substring(0, n) + (char) c + term.substring(n)));
    return concat(concat(concat(delete, transpose), replace), insert);
  }

  /** Returns every string reachable from {@code term} by two applications of {@link #edit1}. */
  private static Stream<String> edit2(String term) {
    return edit1(term).flatMap(DidYouMean::edit1);
  }

  /**
   * Command-line entry point. Prints a suggestion for the first argument, or "Gibberish!" when
   * no known word is within two edits.
   */
  public static void main(String... args) {
    if (args.length == 0) {
      System.err.println("Usage: java DidYouMean <word>");
      return;
    }
    String result = didYouMean(args[0]);
    System.out.println(result == null ? "Gibberish!" : "Did you mean '" + result + "'?");
  }

  /**
   * Suggests a correction for {@code term}, using word frequencies counted from the bundled
   * corpus ({@value #CORPUS_RESOURCE}).
   *
   * @param term the word to check; matching is case-insensitive
   * @return see {@link #didYouMean(String, Map)}; may be {@code null}
   * @throws NullPointerException if {@code term} is null
   * @throws ExceptionInInitializerError on first use if the corpus cannot be found or read
   */
  public static String didYouMean(String term) {
    return didYouMean(term, DefaultDictionary.WORDS);
  }

  /**
   * Suggests a correction for {@code term} using the given word-frequency dictionary.
   *
   * @param term the word to check; it is lower-cased (English locale) before lookup
   * @param dictionary map from lower-case word to its occurrence count
   * @return the lower-cased term if it is in the dictionary. Otherwise the most frequent
   *     dictionary word one edit away. Otherwise the most frequent dictionary word two edits
   *     away. Otherwise {@code null}.
   * @throws NullPointerException if {@code term} or {@code dictionary} is null
   */
  public static String didYouMean(String term, Map<String, Integer> dictionary) {
    Objects.requireNonNull(term, "term");
    Objects.requireNonNull(dictionary, "dictionary");
    String lower = term.toLowerCase(Locale.ENGLISH);
    if (dictionary.containsKey(lower)) {
      return lower;
    }
    Comparator<String> byFrequency =
        Comparator.comparingInt(word -> dictionary.getOrDefault(word, 0));
    // Closer edits win. The (much larger) two-edit stream is only generated when no one-edit
    // candidate is a known word. Both edit levels use the lower-cased term.
    return edit1(lower)
        .filter(dictionary::containsKey)
        .max(byFrequency)
        .orElseGet(() -> edit2(lower)
            .filter(dictionary::containsKey)
            .max(byFrequency)
            .orElse(null));
  }

  /**
   * Reads the bundled corpus from the classpath and counts its words.
   *
   * @throws IllegalStateException if the resource is missing or cannot be read
   */
  private static Map<String, Integer> loadDictionary() {
    InputStream data = DidYouMean.class.getResourceAsStream(CORPUS_RESOURCE);
    if (data == null) {
      throw new IllegalStateException("Dictionary resource not found: " + CORPUS_RESOURCE);
    }
    try (InputStream in = data;
         BufferedReader reader =
             new BufferedReader(new InputStreamReader(in, StandardCharsets.US_ASCII))) {
      return countWords(reader.lines());
    } catch (IOException | UncheckedIOException e) {
      // BufferedReader.lines() reports read failures as UncheckedIOException.
      throw new IllegalStateException("Failed to read dictionary resource " + CORPUS_RESOURCE, e);
    }
  }

  /**
   * Counts the occurrences of each word in {@code lines}. A word is a maximal run of ASCII
   * letters, lower-cased with the English locale. Package-private so it can be exercised
   * without the bundled corpus.
   *
   * @param lines the text to scan; consumed by this call
   * @return an unmodifiable map from word to occurrence count; never contains ""
   */
  static Map<String, Integer> countWords(Stream<String> lines) {
    Map<String, Integer> counts = lines
        .flatMap(WORD_SEPARATOR::splitAsStream)
        // splitAsStream yields "" for an empty line and for a line starting with a separator.
        .filter(word -> !word.isEmpty())
        .map(word -> word.toLowerCase(Locale.ENGLISH))
        .collect(Collectors.toMap(Function.identity(), word -> 1, Integer::sum));
    return Collections.unmodifiableMap(counts);
  }
}

No comments:

Post a Comment

All comments are moderated