Friday, December 28, 2012

Analyzing the Enron Data: Frequency Distribution, Page Rank and Document Clustering


I've been using the Enron Dataset for a couple of projects now, and I figured that it would be interesting to see if I could glean some information out of the data. One can of course simply read the Wikipedia article, but that would be too easy and not as much fun :-).

My focus in this analysis is on the "what" and the "who", ie, what the important ideas in this corpus are and who the principal players are. For that I did the following:

  • Extract the words from Lucene's inverted index into (term, docID, freq) triples. Using these, I construct a frequency distribution of words in the corpus. Looking at the most frequent words gives us an idea of what is being discussed.
  • Extract the email (from, {to, cc, bcc}) pairs from MongoDB. Using these, I piggyback on Scalding's PageRank implementation to produce a list of email addresses ordered by page rank. This gives us an idea of the "important" players.
  • Using the triples extracted from Lucene, construct tuples of (docID, termvector), then cluster the documents using KMeans. This gives us an idea of the spread of ideas in the corpus. Originally, the idea was to use Mahout for the clustering, but I ended up using Weka instead.

I also wanted to get more familiar with Scalding beyond the basic stuff I did before, so I used that where I would have used Hadoop previously. The rest of the code is in Scala as usual.

Unfortunately, my Scalding-fu was not strong enough, because I couldn't figure out how to run Scalding jobs in non-local mode, and I was running out of memory for some of the jobs in local mode. Also, Mahout expects its document vectors to be in SequenceFile format, but Scalding does not allow you to write SequenceFiles in local mode. As a result, I converted the document vector file to ARFF format and used Weka for the clustering instead. I have asked about this on the Cascading-Users mailing list.

Frequency Distribution


Here is the code to extract the (term, docID, freq) tuples from the Lucene index. As you can see, the generate method takes three cutoff parameters: the minimum document frequency, the minimum total term frequency and the minimum term frequency. I added these cutoffs because the unfiltered output contained about 700 million triples and was causing the next job (FreqDist) to throw an OutOfMemoryError.

// Source: src/main/scala/com/mycompany/solr4extras/corpus/Lucene4TermFreq.scala
package com.mycompany.solr4extras.corpus

import java.io.{PrintWriter, FileWriter, File}

import org.apache.lucene.index.{MultiFields, IndexReader, DocsEnum}
import org.apache.lucene.search.DocIdSetIterator
import org.apache.lucene.store.NIOFSDirectory
import org.apache.lucene.util.BytesRef

/**
 * Reads a Lucene4 index (new API) and writes out a
 * text file as (term, docID, frequency_of_term_in_doc).
 * @param indexDir the location of the Lucene index.
 */
class Lucene4TermFreq(indexDir: String) {

  /**
   * @param outputFile the output file name.
   * @param minDocs terms present in minDocs or fewer
   *        documents are ignored.
   * @param minTTF terms whose Total Term Frequency is
   *        minTTF or less are ignored.
   * @param minTermFreq a (term, doc) pair is written only if the
   *        term occurs more than minTermFreq times in the doc.
   */
  def generate(outputFile: String, minDocs: Int,
      minTTF: Int, minTermFreq: Int): Unit = {
    val reader = IndexReader.open(
      new NIOFSDirectory(new File(indexDir), null))
    val writer = new PrintWriter(new FileWriter(outputFile), true)
    // skip deleted documents (liveDocs is null if the index has none)
    val liveDocs = MultiFields.getLiveDocs(reader)
    val terms = MultiFields.getTerms(reader, "body").iterator(null)
    var term: BytesRef = null
    var docs: DocsEnum = null
    do {
      term = terms.next
      if (term != null) {
        val docFreq = terms.docFreq
        val ttf = terms.totalTermFreq
        if (docFreq > minDocs && ttf > minTTF) {
          // reuse the DocsEnum across terms
          docs = terms.docs(liveDocs, docs)
          var docID: Int = -1
          do {
            docID = docs.nextDoc
            if (docID != DocIdSetIterator.NO_MORE_DOCS) {
              val termFreq = docs.freq
              if (termFreq > minTermFreq)
                writer.println("%s\t%d\t%d".format(
                  term.utf8ToString, docID, termFreq))
            }
          } while (docID != DocIdSetIterator.NO_MORE_DOCS)
        }
      }
    } while (term != null)
    writer.flush
    writer.close
    reader.close
  }
}

To decide the cutoffs, I first plotted the sorted total term frequencies (TTFs) of all words in the corpus. This results in the Zipf distribution shown below, with the last 20,000 terms contributing most of the occurrence count. Cutting the TTFs off at about 1,000 occurrences, setting the minimum document frequency to 5 (a term must exist in at least 5 documents) and the minimum term frequency to 10 (a term must occur at least 10 times in a document to be counted) resulted in a more manageable number of about 1.27 million triples.
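To make the interaction of the three cutoffs concrete, here is a minimal in-memory sketch over a list of (term, docID, freq) postings. It is not the Lucene code above: the Posting case class and Cutoffs.applyCutoffs are hypothetical names. It mirrors the strict ">" comparisons used in Lucene4TermFreq, applying the two corpus-level cutoffs (document frequency, TTF) per term first and the per-document cutoff second.

```scala
// Hypothetical in-memory illustration of the three cutoffs used by Lucene4TermFreq.
case class Posting(term: String, docID: Int, freq: Int)

object Cutoffs {
  /** Keeps postings whose term appears in more than minDocs documents,
    * whose total term frequency exceeds minTTF, and whose in-document
    * frequency exceeds minTermFreq (strict comparisons, as in the Lucene code). */
  def applyCutoffs(postings: Seq[Posting], minDocs: Int,
      minTTF: Int, minTermFreq: Int): Seq[Posting] = {
    val byTerm = postings.groupBy(_.term)
    // terms passing the corpus-level cutoffs (document frequency and TTF)
    val keptTerms = byTerm.collect {
      case (term, ps) if ps.map(_.docID).distinct.size > minDocs &&
                         ps.map(_.freq).sum > minTTF => term
    }.toSet
    // then apply the per-document cutoff
    postings.filter(p => keptTerms.contains(p.term) && p.freq > minTermFreq)
  }
}
```

Note that a term's TTF is computed before the per-document filter, so postings dropped by minTermFreq still count toward the term's TTF, just as they do in the Lucene index statistics.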


Here is the code to compute the Frequency Distribution of the terms in the corpus.

// Source: src/main/scala/com/mycompany/solr4extras/corpus/FreqDist.scala
package com.mycompany.solr4extras.corpus

import com.twitter.scalding.{Tsv, Job, Args}

import cascading.pipe.joiner.LeftJoin

/**
 * Reads input of the form (term docID freq), removes stopword
 * terms based on a stop word list, sums up the term frequency
 * across docs and outputs the term frequency counts sorted by
 * count descending as (term count).
 * NOTE: this can also be done directly from Lucene using 
 * totalTermFreq.
 */
class FreqDist(args: Args) extends Job(args) {

  val stopwords = Tsv(args("stopwords"), ('stopword)).read
  val input = Tsv(args("input"), ('term, 'docID, 'freq))
  val output = Tsv(args("output"))
  input.read.
    joinWithSmaller('term -> 'stopword, stopwords, joiner = new LeftJoin).
    filter('stopword) { stopword: String => 
      (stopword == null || stopword.isEmpty) 
    }.
    groupBy('term) { _.sum('freq) }.
    groupAll { _.sortBy('freq).reverse }.
    write(output)
}
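If you do not have a Scalding setup, you can express the same pipeline (drop stopwords, sum the per-document frequencies for each term, sort by count descending) with plain Scala collections. This sketch assumes the triples fit in memory, which is exactly what did not hold for the full corpus. FreqDistSketch.freqDist is a hypothetical name, and its filterNot stands in for the LeftJoin-plus-null-filter idiom used in FreqDist.

```scala
object FreqDistSketch {
  /** triples: (term, docID, freq). Returns (term, totalFreq) sorted by count descending. */
  def freqDist(triples: Seq[(String, Int, Int)],
      stopwords: Set[String]): Seq[(String, Int)] =
    triples
      .filterNot { case (term, _, _) => stopwords.contains(term) } // stopword removal
      .groupBy(_._1)                                                // group by term
      .map { case (term, rows) => term -> rows.map(_._3).sum }      // sum freq across docs
      .toSeq
      .sortBy { case (term, count) => (-count, term) }              // descending, ties by term
}
```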

The top 100 words (and their raw frequencies) from the resulting frequency distribution are shown below:

enron (1349349), ect (1133513), hou (578328), subject (446732), pm (388238), http (325908), power (309063), cc (303380), enron.com (290452), energy (286475), corp (262331), message (245042), mail (234409), time (217656), gas (216405), company (189878), market (181420), information (180539), ees (176279), original (170590), call (153876), california (152559), business (148100), forwarded (145700), day (138167), na (132233), td (132135), font (132130), price (131493), week (131246), state (130245), year (127376), email (124309), attached (121637), houston (119962), image (118982), john (113327), meeting (111627), agreement (111621), mark (111518), deal (108930), make (106030), group (105643), trading (105130), questions (103417), enron_development (102359), contact (97935), date (96189), back (95138), million (93875), services (93825), work (92014), jeff (89695), today (89682), report (89632), electricity (88541), service (88489), monday (87089), prices (85148), free (83838), friday (83750), credit (83487), contract (82981), system (80955), financial (80927), good (80489), review (78720), fax (78219), management (76054), companies (75371), david (74984), news (74666), number (73917), file (73378), jones (72288), thursday (72182), order (71953), list (71845), send (71824), forward (71651), tuesday (71515), office (70850), october (70826), based (70518), enronxgate (69905), wednesday (69339), risk (69091), change (68911), received (68688), mike (68609), issues (68449), team (67302), bill (67289), click (66993), plan (66953), customers (66047), communications (65620), november (65478), phone (65434), provide (65347)

Page Rank


Next, I wanted to rank the players (the email authors) by importance. In any organization, people tend to email at their own level but include their immediate bosses on the CC or BCC as a form of CYA. So top management would get relatively few (but highly ranked) emails from middle management, and middle management would get many (low ranked) emails from their underlings. Sorting the email authors in descending order of page rank should therefore tell us who the major players are.
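A small power-iteration PageRank over (from, to) edges illustrates this intuition. The sketch uses the classic non-normalized formulation: every node starts at 1.0, and on each iteration rank(v) = (1 - d) + d * (sum of rank(u) / outDegree(u) over all edges u -> v), with an assumed damping factor d = 0.85. It is not the Scalding example's code, and details such as the handling of nodes without outgoing edges may differ.

```scala
object PageRankSketch {
  /** Iterative PageRank over directed edges (from -> to), non-normalized variant:
    * every node starts at 1.0 and
    * rank(v) = (1 - d) + d * sum(rank(u) / outDegree(u)) over edges u -> v. */
  def pageRank(edges: Seq[(Int, Int)], iterations: Int,
      d: Double = 0.85): Map[Int, Double] = {
    val nodes = edges.flatMap { case (f, t) => Seq(f, t) }.distinct
    val outDegree = edges.groupBy(_._1).map { case (n, es) => n -> es.size }
    var rank = nodes.map(n => n -> 1.0).toMap
    for (_ <- 1 to iterations) {
      // each node sends rank / outDegree along every outgoing edge
      val incoming = edges.groupBy(_._2).map { case (to, es) =>
        to -> es.map { case (from, _) => rank(from) / outDegree(from) }.sum
      }
      rank = nodes.map(n => n -> ((1 - d) + d * incoming.getOrElse(n, 0.0))).toMap
    }
    rank
  }
}
```

In a star where several people mail one person who mails back to only one of them, the hub ends up with the highest rank, and the person it replies to ranks above the others.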

The input data comes from the MongoDB database I populated for my previous project. I had to jump through a few hoops because the data is encrypted, but the net result is two files: the first contains (from_email, from_id) tuples and the second contains (from_id, {to_id, cc_id, bcc_id}) tuples. I generated two files because the PageRank implementation that comes with the Scalding distribution expects numeric node IDs. Here is the extraction code:

// Source: src/main/scala/com/mycompany/solr4extras/corpus/MongoEmailPairs.scala
package com.mycompany.solr4extras.corpus

import java.io.{PrintWriter, FileWriter, File}
import java.util.concurrent.atomic.AtomicInteger

import org.apache.commons.codec.binary.Hex
import org.apache.lucene.index.{IndexReader, MultiFields}
import org.apache.lucene.store.NIOFSDirectory

import com.mongodb.casbah.Imports.{wrapDBObj, wrapDBList, MongoDBObject, MongoConnection, BasicDBList}
import com.mycompany.solr4extras.secure.CryptUtils

class MongoEmailPairs(host: String, port: Int, db: String, 
    indexDir: String) {

  val conn = MongoConnection(host, port)
  val emails = conn(db)("emails")
  val users = conn(db)("users")
  val reader = IndexReader.open(
    new NIOFSDirectory(new File(indexDir), null))

  def generate(refFile: String, outputFile: String): Unit = {
    val counter = new AtomicInteger(0)
    // email -> (AES key, initialization vector, numeric id)
    val userKeys = users.find().map(user => 
      user.as[String]("email") -> 
      (Hex.decodeHex(user.as[String]("key").toCharArray), 
      Hex.decodeHex(user.as[String]("initvector").toCharArray),
      counter.incrementAndGet)).toMap
    // write out dictionary file (email, id) for reference
    val refWriter = new PrintWriter(new FileWriter(new File(refFile)), true)
    userKeys.foreach(user =>
      refWriter.println("%s\t%d".format(user._1, user._2._3)))
    refWriter.flush
    refWriter.close
    // write out main file (from_id, target_id) as required by PageRank
    val dataWriter = new PrintWriter(new FileWriter(new File(outputFile)), true)
    // skip deleted documents (liveDocs is null if the index has none)
    val liveDocs = MultiFields.getLiveDocs(reader)
    (0 until reader.maxDoc).
        filter(i => liveDocs == null || liveDocs.get(i)).
        foreach(i => {
      val doc = reader.document(i)
      val messageID = doc.get("message_id")
      val author = doc.get("from")
      emails.findOne(MongoDBObject("message_id" -> messageID)) match {
        case Some(email) => {
          try {
            val from = CryptUtils.decrypt(
              Hex.decodeHex(email.as[String]("from").toCharArray), 
              userKeys(author)._1, userKeys(author)._2)
            val fromId = userKeys(from)._3
            // to, cc and bcc are optional; a missing list counts as empty
            val targets = 
              (try {
                email.as[BasicDBList]("to").toList  
              } catch {
                case e: NoSuchElementException => List()
              }) ++
              (try {
                email.as[BasicDBList]("cc").toList
              } catch {
                case e: NoSuchElementException => List()
              }) ++
              (try {
                email.as[BasicDBList]("bcc").toList
              } catch {
                case e: NoSuchElementException => List()
              })
            targets.foreach(target => {
              val targetEmail = CryptUtils.decrypt(Hex.decodeHex(
                target.asInstanceOf[String].toCharArray), 
                userKeys(author)._1, userKeys(author)._2).trim
              val targetEmailId = userKeys(targetEmail)._3
              dataWriter.println("%d\t%d".format(fromId, targetEmailId))
            })
          } catch {
            // TODO: BadPaddingException, likely caused by 
            // problems during population. Fix, but skip for now
            case e: Exception =>
              println("error, skipping " + messageID + ": " + e.getMessage)
          }
        }
        case None => // skip
      }
    })
    dataWriter.flush
    dataWriter.close
    reader.close
    conn.close
  }
}
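Without Lucene, MongoDB and decryption, this step reduces to a dictionary encoding of addresses followed by one (from_id, target_id) edge per to/cc/bcc recipient. Below is a plain-collections sketch with a hypothetical Email case class and hypothetical function names. Unlike the code above, which numbers users in MongoDB iteration order, it assigns 1-based IDs in first-seen order.

```scala
// Hypothetical plain-Scala version of the edge extraction, minus Lucene, MongoDB and decryption.
case class Email(from: String, to: Seq[String], cc: Seq[String], bcc: Seq[String])

object EmailPairsSketch {
  /** Assigns 1-based numeric IDs to every address seen, in first-seen order. */
  def assignIds(emails: Seq[Email]): Map[String, Int] =
    emails.flatMap(e => e.from +: (e.to ++ e.cc ++ e.bcc))
      .map(_.trim)
      .distinct
      .zipWithIndex
      .map { case (addr, i) => addr -> (i + 1) }
      .toMap

  /** One (fromId, targetId) pair per to/cc/bcc recipient, as PageRank's input needs. */
  def edges(emails: Seq[Email], ids: Map[String, Int]): Seq[(Int, Int)] =
    for {
      e <- emails
      target <- e.to ++ e.cc ++ e.bcc
    } yield (ids(e.from.trim), ids(target.trim))
}
```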

We then create our own subclass of PageRank and override its initialize method, which produces the input Pipe for PageRank from the email pairs we just extracted from MongoDB. Here we group by from_id and aggregate the to_ids into a comma-separated list, add a column with the initial page rank (1.0), and rename the columns so the parent class can use them. Once the job has finished, we post-process the output to produce a list of sender email addresses and their page ranks in descending order. Here is the code for these two classes:

// Source: src/main/scala/com/mycompany/solr4extras/corpus/MailRank.scala
package com.mycompany.solr4extras.corpus

import com.twitter.scalding.examples.PageRank
import com.twitter.scalding.{Tsv, Job, Args}

import cascading.pipe.joiner.LeftJoin
import cascading.pipe.Pipe

/**
 * Converts data generated by MongoEmailPairs of the 
 * form: (from_id, {to_id|cc_id|bcc_id}) to the format
 * required by PageRank, ie (from_id, List(to_ids), pagerank)
 */
class MailRank(args: Args) extends PageRank(args) {

  override def initialize(nodeCol: Symbol, neighCol: Symbol, 
      pageRank: Symbol): Pipe = {
    val input = Tsv(args("input"), ('from, 'to))
    input.read.
      groupBy('from) { _.toList[String]('to -> 'tos) }.
      // neighbors as a comma-separated string, as PageRank expects
      map('tos -> 'tosf) { tos: List[String] => tos.mkString(",") }.
      // every node starts with a page rank of 1.0
      map('from -> 'prob) { from: String => 1.0 }.
      project('from, 'tosf, 'prob).
      mapTo((0, 1, 2) -> (nodeCol, neighCol, pageRank)) {
        input : (Long, String, Double) => input
      }
  }
}

/**
 * Converts the format returned by PageRank, ie:
 * (from_id, List(to_id), final_pagerank) to 
 * (from_email, final_pagerank) sorted by pagerank 
 * descending.
 */
class MailRankPostProcessor(args: Args) extends Job(args) {
  
  val input = Tsv(args("input"), ('from, 'tos, 'rank))
  val output = Tsv(args("output"))

  val reference = Tsv(args("reference"), ('ref_email, 'ref_from)).read
  input.read.
    project('from, 'rank).
    joinWithSmaller('from -> 'ref_from, reference, joiner = new LeftJoin).
    project('ref_email, 'rank).
    groupAll { _.sortBy('rank).reverse }.
    write(output)
}
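The reshaping done by MailRank.initialize and MailRankPostProcessor amounts to two small collection operations. The sketch below works on in-memory data, and its names are hypothetical. It groups edges into (from_id, "to1,to2,...", 1.0) rows, the comma-separated neighbor format the Scalding PageRank example reads. It then maps final ranks back to addresses, sorted descending.

```scala
object MailRankSketch {
  /** (fromId, toId) edges -> (fromId, comma-separated neighbors, initial rank 1.0). */
  def toAdjacency(edges: Seq[(Long, Long)]): Seq[(Long, String, Double)] =
    edges.groupBy(_._1).toSeq.sortBy(_._1).map { case (from, es) =>
      (from, es.map(_._2).mkString(","), 1.0)
    }

  /** Joins (id, rank) with the id -> email reference data, sorted by rank descending.
    * IDs missing from the reference map are dropped here, whereas the Scalding
    * LeftJoin keeps them with a null email. */
  def postProcess(ranks: Seq[(Long, Double)],
      refs: Map[Long, String]): Seq[(String, Double)] =
    ranks.flatMap { case (id, rank) => refs.get(id).map(email => (email, rank)) }
      .sortBy(-_._2)
}
```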

Here are the 15 email addresses with the highest page rank. If you followed the Enron trial, you may recognize a few names here:

vince.kaminski@enron.com    129.6456344199462
sara.shackleton@enron.com   123.65630393847148
louise.kitchen@enron.com    117.4938083777274
jeff.dasovich@enron.com     116.83443565190146
tana.jones@enron.com        113.12379367608482
mark.taylor@enron.com       109.67593395984632
sally.beck@enron.com        108.4905670421611
ebass@enron.com             103.54919812912469
jeff.skilling@enron.com     100.57545577565519
steven.kean@enron.com        99.69270847011283
john.lavorato@enron.com      90.16447940199485
gerald.nemec@enron.com       88.4001157213643
kenneth.lay@enron.com        88.1467699737448
richard.shapiro@enron.com    82.10524578705625
kay.mann@enron.com           69.19222780384432

Document Clustering


Finally, I decided to cluster the documents to find out whether multiple topics were being discussed in the corpus. The input to this process is the (term, docID, freq) triples that we generated from Lucene, and the output is a set of (docID, {term frequency vector}) tuples. The job builds a partial vector for each (docID, term) pair and aggregates them into document vectors using Mahout's SequentialAccessSparseVector class. I found a post on Software Anatomy very useful when writing this code; much of the DocVector class below is based on the code shown there.

// Source: src/main/scala/com/mycompany/solr4extras/corpus/DocVector.scala
package com.mycompany.solr4extras.corpus

import java.io.{PrintWriter, FileWriter, File}

import scala.collection.mutable.ListBuffer
import scala.io.Source

import org.apache.mahout.math.{VectorWritable, SequentialAccessSparseVector}

import com.twitter.scalding.{Tsv, TextLine, Job, Args}

import cascading.pipe.joiner.LeftJoin

/**
 * Reads input of the form (term freq) of most frequent terms,
 * and builds a dictionary file. Using this file, creates a 
 * collection of docID to sparse doc vector mappings of the 
 * form (docID, {termID:freq,...}).
 */
class DocVector(args: Args) extends Job(args) {

  val input = Tsv(args("input"), ('term, 'docID, 'freq))
  val termcounts = TextLine(args("termcounts"))
  
  val dictOutput = Tsv(args("dictionary"))
  val output = Tsv(args("docvector"))
  
  // (term freq) => (term num)
  val dictionary = termcounts.read.
    project('num, 'line).
    map('line -> 'word) { line: String => line.split('\t')(0) }.
    project('word, 'num)

  // input: (term, docID, freq)
  // join with dictionary and write document as (docId, docvector) 
  input.read.
    joinWithSmaller('term -> 'word, dictionary, joiner = new LeftJoin).
    filter('word) { word: String => (!(word == null || word.isEmpty)) }.
    project('docID, 'num, 'freq).
    // one partial vector per (docID, term) with a single non-zero entry
    map(('docID, 'num, 'freq) -> ('docId, 'pvec)) { 
      doc: (String, Int, Int) =>
        val pvec = new SequentialAccessSparseVector(
          args("vocabsize").toInt)
        pvec.set(doc._2, doc._3)
        (doc._1, new VectorWritable(pvec))
    }.
    // sum the partial vectors for each document...
    groupBy('docId) { 
      group => group.reduce('pvec -> 'vec) {
        (left: VectorWritable, right: VectorWritable) => 
          new VectorWritable(left.get.plus(right.get))
      }
    }.
    // ...then normalize once, so the result does not depend on
    // the order in which the partial vectors were combined
    map('vec -> 'vec) { vec: VectorWritable =>
      new VectorWritable(vec.get.normalize)
    }.
    write(output)
    
    // save the dictionary as (term, idx)    
    dictionary.write(dictOutput)
}

/**
 * Converts the Document Vector file to an ARFF file for 
 * consumption by Weka.
 */
class DocVectorToArff {
  
  def generate(input: String, output: String, 
      numDimensions: Int): Unit = {
    val writer = new PrintWriter(new FileWriter(new File(output)), true)
    // header
    writer.println("@relation docvector\n")
    (1 to numDimensions).foreach(n => 
      writer.println("@attribute vec" + n + " numeric"))
    writer.println("\n@data\n")
    // body
    Source.fromFile(new File(input)).getLines.foreach(line => { 
      writer.println(line.split('\t')(1).replaceAll(":", " "))
    })
    writer.flush
    writer.close
  }
}

/**
 * Reads output from Weka Explorer SimpleKMeans run (slightly
 * modified to remove header information) to produce a list
 * of top N words from each cluster.
 */
class WekaClusterDumper {
  
  def dump(input: String, dictionary: String, 
      output: String, topN: Int): Unit = {
    
    // build up map of terms from dictionary
    val dict = Source.fromFile(new File(dictionary)).getLines.
      map(line => { 
        val cols = line.split("\t")
        cols(1).toInt -> cols(0)
    }).toMap
    // build up elements list from weka output. Expected line layout
    // (after manual cleanup): attribute number, full-data value, then
    // one centroid value per cluster
    val numClusters = 5 // must match the number of clusters requested in Weka
    val clusterScores = Array.fill(numClusters)(new ListBuffer[(Int, Double)])
    Source.fromFile(new File(input)).getLines.
      foreach(line => {
        val cols = line.split("\\s+")
        val idx = cols(0).toInt - 1
        val scores = cols.slice(2, 2 + numClusters) 
        (0 until numClusters).foreach(i => 
          if (scores(i).toDouble > 0.0D)
            clusterScores(i) += ((idx, scores(i).toDouble))
        )
    })
    // sort each clusterScore by score descending and get the
    // corresponding words from the dictionary by idx
    val writer = new PrintWriter(new FileWriter(new File(output)), true)
    clusterScores.zipWithIndex.foreach { case (clusterScore, i) =>
      writer.println("Cluster #" + i)
      clusterScore.toList.sortWith(_._2 > _._2).
        take(topN).foreach(tuple => {
          val word = dict(tuple._1)
          writer.println("  " + word + " (" + tuple._2 + ")")
        })
    }
    writer.flush
    writer.close
  }
}
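Mahout's normalize() divides a vector by its L2 norm, and L2 normalization does not distribute over addition. In general, normalize(normalize(a + b) + c) differs from normalize(a + b + c), so normalizing inside a pairwise reduce makes each document vector depend on the order in which the partial vectors are combined. The safe pattern is to sum all partial vectors of a document and normalize once. The sketch below uses Map[Int, Double] as a stand-in for Mahout's sparse vector; all names are hypothetical.

```scala
object DocVectorSketch {
  type SparseVec = Map[Int, Double] // termID -> weight

  /** Component-wise sum of two sparse vectors. */
  def plus(a: SparseVec, b: SparseVec): SparseVec =
    b.foldLeft(a) { case (acc, (i, v)) => acc.updated(i, acc.getOrElse(i, 0.0) + v) }

  /** Divides every component by the vector's L2 norm (no-op for the zero vector). */
  def normalize(v: SparseVec): SparseVec = {
    val norm = math.sqrt(v.values.map(x => x * x).sum)
    if (norm == 0.0) v else v.map { case (i, x) => i -> x / norm }
  }

  /** (docID, termID, freq) triples -> one L2-normalized sparse vector per document. */
  def docVectors(triples: Seq[(String, Int, Int)]): Map[String, SparseVec] =
    triples.groupBy(_._1).map { case (doc, rows) =>
      val partials = rows.map { case (_, termID, freq) => Map(termID -> freq.toDouble) }
      doc -> normalize(partials.foldLeft(Map.empty[Int, Double])(plus)) // sum, then normalize once
    }
}
```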

As I mentioned above, the output of DocVector was supposed to be passed into Mahout's KMeans driver (and the Canopy driver for the initial centroids), but I could not generate the SequenceFiles that Mahout expects, so I converted the DocVector output to an ARFF file that I could pass into Weka.
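Weka's sparse ARFF syntax writes each instance as {index value, index value, ...} with 0-based attribute indices in ascending order. That is why replacing the colons in a {index:value,...} vector dump is enough for DocVectorToArff, assuming the vectors are serialized in that form. Here is a self-contained sketch that builds the same header and sparse rows from in-memory vectors; ArffSketch.toArff is a hypothetical name.

```scala
object ArffSketch {
  /** Writes an ARFF document using Weka's sparse instance syntax {index value, ...},
    * where attribute indices are 0-based. */
  def toArff(relation: String, numDims: Int,
      rows: Seq[Map[Int, Double]]): String = {
    val header = s"@relation $relation\n\n" +
      (1 to numDims).map(n => s"@attribute vec$n numeric").mkString("\n") +
      "\n\n@data\n"
    val body = rows.map { row =>
      // sparse entries are written in ascending index order
      row.toSeq.sortBy(_._1).map { case (i, v) => s"$i $v" }.mkString("{", ",", "}")
    }.mkString("\n")
    header + body + "\n"
  }
}
```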

I then ran SimpleKMeans requesting 5 clusters with 10 iterations, and the first attempt resulted in an OutOfMemoryError. Since I have a pretty bad-ass MacBook Pro with 8GB RAM (okay, it was bad-ass when I bought it 3 years ago), I upped the default -Xmx for Weka from 256MB to 2GB (this is in the Info.plist file under the /Applications/weka directory for those of you who use Macs), and life was good again.
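SimpleKMeans runs batch (Lloyd-style) k-means. Each instance is assigned to its nearest centroid, each centroid is recomputed as the mean of its instances, and the process repeats. The minimal dense-vector sketch below takes caller-supplied initial centroids and runs a fixed number of iterations. It is not Weka's code, and it ignores details such as how Weka seeds the initial centroids or normalizes attributes.

```scala
object KMeansSketch {
  type Vec = Array[Double]

  private def sqDist(a: Vec, b: Vec): Double =
    a.indices.map(i => (a(i) - b(i)) * (a(i) - b(i))).sum

  /** Lloyd's k-means: assign each point to its nearest centroid, then move each
    * centroid to the mean of its points; repeats for a fixed number of iterations.
    * An empty cluster keeps its previous centroid. Returns the final centroids. */
  def kmeans(points: Seq[Vec], initial: Seq[Vec], iterations: Int): Seq[Vec] = {
    var centroids = initial
    for (_ <- 1 to iterations) {
      val assigned = points.groupBy(p =>
        centroids.indices.minBy(c => sqDist(p, centroids(c))))
      centroids = centroids.indices.map { c =>
        assigned.get(c) match {
          case Some(ps) =>
            val dim = ps.head.length
            (0 until dim).map(i => ps.map(_(i)).sum / ps.size).toArray
          case None => centroids(c)
        }
      }
    }
    centroids
  }
}
```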

In any case, after a while, Weka dumped the results of the computation to the Explorer console. I copied them from there and passed them through my WekaClusterDumper class. This class parses the Weka clustering output (minus the headers, which I removed manually) and prints the top N words in each cluster. Since Weka does not know the actual terms being clustered (only their positions in the dictionary file generated by DocVector), the cluster dumper uses this file to look up the actual terms. Here is the output of the dump.

Cluster #0
  ect (13.3978)
  image (4.639)
  enron.com (3.7695)
Cluster #1
  enron (20.7614)
Cluster #2
  http (0.3507)
  enron.com (0.309)
  enron_development (0.2871)
  mail (0.1964)
  ees (0.1615)
  image (0.1558)
  hou (0.1097)
  message (0.1039)
  ect (0.0909)
  energy (0.0892)
Cluster #3
  mail (56.5979)
Cluster #4
  study (12.2941)

As you can see, according to the clustering, there does not seem to be much variety in the discussion going on in the Enron dataset.
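Once the Weka output is parsed, WekaClusterDumper does essentially this: for each centroid, it takes the N dimensions with the highest positive values and translates their indices back to terms through the dictionary. Here is a sketch of that step on an in-memory centroid matrix. The dictionary map and function name are hypothetical, and the sketch takes the cluster count from the input rather than hard-coding it.

```scala
object ClusterDumpSketch {
  /** centroids(c)(i) is the weight of dictionary entry i in cluster c.
    * Returns, per cluster, up to topN (term, weight) pairs with positive weight, highest first. */
  def topTerms(centroids: Seq[Array[Double]], dict: Map[Int, String],
      topN: Int): Seq[Seq[(String, Double)]] =
    centroids.map { centroid =>
      centroid.zipWithIndex.toSeq
        .filter { case (w, _) => w > 0.0 }
        .sortBy { case (w, _) => -w }
        .take(topN)
        .map { case (w, i) => (dict(i), w) }
    }
}
```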

Finally, I used a single object to call the classes described above from sbt (using "sbt run"). The whole process was very interactive, so don't run the code as is; it will fail. I ran each step individually, some multiple times, and commented out the previous blocks as I went forward.

// Source: src/main/scala/com/mycompany/solr4extras/corpus/CorpusAnalyzer.scala
package com.mycompany.solr4extras.corpus

import com.twitter.scalding.Args

object CorpusAnalyzer extends App {

  //////////// frequency distribution ///////////
  
  (new Lucene4TermFreq("/Users/sujit/Downloads/apache-solr-4.0.0/example/solr/collection1/data/index")).
    generate("data/input/corpus_freqs.txt", 50, 1000, 10)

  (new FreqDist(Args(List(
    "--local", "", 
    "--input", "data/input/corpus_freqs.txt", 
    "--output", "data/output/freq_dist.txt",
    "--stopwords", "/Users/sujit/Downloads/apache-solr-4.0.0/example/solr/collection1/conf/stopwords.txt")))).
    run

  //////// mail rank ////////
  
  (new MongoEmailPairs("localhost", 27017, "solr4secure", 
    "/Users/sujit/Downloads/apache-solr-4.0.0/example/solr/collection1/data/index")).
    generate("data/input/email_refs.txt",
    "data/input/email_pairs.txt")
  
  (new MailRank(Args(List(
    "--local", "",
    "--input", "data/input/email_pairs.txt",
    "--output", "data/output/mailrank.txt",
    "--iterations", "10")))).
    run

  (new MailRankPostProcessor(Args(List(
    "--local", "",
    "--input", "data/output/mailrank.txt",
    "--reference", "data/input/email_refs.txt",
    "--output", "data/output/mailrank_final.txt")))).run

  ////////////// document clustering ///////////////
    
  (new DocVector(Args(List(
    "--local", "",
    "--input", "data/input/corpus_freqs.txt",
    "--termcounts", "data/input/freq_words.txt",
    "--vocabsize", "1326", // cut -f1 data/input/freq_words.txt | sort | uniq | wc -l
    "--dictionary", "data/output/dictionary.txt",
    "--docvector", "data/output/docvector.txt")))).
    run

  (new DocVectorToArff()).generate(
    "/Users/sujit/Projects/solr4-extras/data/output/docvector.txt", 
    "data/output/docvector.arff", 1326)

  (new WekaClusterDumper()).dump(
    "/Users/sujit/Projects/solr4-extras/data/output/weka_cluster_output.txt",
    "/Users/sujit/Projects/solr4-extras/data/output/dictionary.txt",
    "data/output/cluster.dump",10)
    
}

Well, this is all I have for today. The source code for this stuff is also available on my GitHub project page if you want to play around with it on your own. Hope you enjoyed it. If I don't get a chance to post again before next year (unlikely given that there are just 3 more days left and my mean time between posts is now about 14 days), be safe and have a very Happy New Year. Here's hoping for a lot of fun in 2013 (the International Year of Statistics).


Saturday, December 15, 2012

Searching an Encrypted Document Collection with Solr4, MongoDB and JCE


A while back, someone asked me if it was possible to make an encrypted document collection searchable through Solr. The use case was patient records - the patient is the owner of the records, and the only person who can search through them, unless he temporarily grants permission to someone else (for example his doctor) for diagnostic purposes. I couldn't come up with a good way of doing it off the bat, but after some thought, came up with a design that roughly looked like the picture below:


I just finished the M101 - MongoDB for Developers online class conducted by 10gen, and I had been meaning to check out the recently released Solr4, so this seemed like a good opportunity to implement the design and gain some hands-on experience with MongoDB and Solr4 at the same time. So that's what I did, and the rest of this post describes the implementation. Instead of patient records, I used emails from the Enron dataset, which I had on hand from a previous project.

Essentially, the idea is that during indexing, all but two of the fields are indexed but not stored in the Lucene index, so they can be searched but not retrieved. The two stored fields are unique IDs into a database (MongoDB in this case) that holds the encrypted version of the document: a unique ID identifying the user whose key is used to encrypt the document (in this case the email address), and a unique ID identifying the document itself (in this case the messageID).
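This split can be modeled in a few lines. In the sketch below (names hypothetical), message_id and from are the stored Solr fields, as in the schema below, and every other field is indexed only. On the MongoDB side, the sketch keeps message_id in the clear and encrypts everything else with a caller-supplied function. That layout is an assumption based on the MongoEmailPairs code earlier on this page, which looks emails up by message_id and decrypts from, to, cc and bcc.

```scala
object IndexSplitSketch {
  // fields stored in Solr; they act as keys into MongoDB
  val storedFields = Set("message_id", "from")

  case class SolrField(name: String, value: String, stored: Boolean)

  /** Splits one email into Solr fields (all searchable, only the key fields stored)
    * and a MongoDB record keyed by message_id with every other field encrypted.
    * encrypt is a caller-supplied stand-in for the per-user AES cipher. */
  def prepare(doc: Map[String, String], encrypt: String => String)
      : (Seq[SolrField], Map[String, String]) = {
    val solrFields = doc.toSeq.sortBy(_._1).map { case (n, v) =>
      SolrField(n, v, stored = storedFields.contains(n))
    }
    val mongoRecord = doc.map { case (n, v) =>
      n -> (if (n == "message_id") v else encrypt(v))
    }
    (solrFields, mongoRecord)
  }
}
```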

The final part of the puzzle is a custom Solr component that is configured at the end of a standard SearchHandler chain. It reads the response from the previous component (a DocList containing (docID, score) pairs), calls into the Lucene index to get the corresponding messageIDs for the page slice, then calls into MongoDB to get the encrypted documents, and decrypts each message using the user's key (which is also retrieved from MongoDB using the userID).
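The component's per-request flow can be sketched independently of Solr, Lucene and MongoDB by passing the three lookups in as functions. All names here are hypothetical stand-ins: storedKeys plays the role of reading the stored fields from the index, fetchEncrypted the MongoDB lookup by messageID, and decrypt the per-user cipher.

```scala
object SecureSearchSketch {
  /** Mirrors the component's steps for one page of hits:
    * docIDs -> (messageID, userID) from the index -> encrypted doc from the store
    * -> decrypted with that user's key. Hits missing from the store are skipped. */
  def decryptPage(
      docIDs: Seq[Int], start: Int, rows: Int,
      storedKeys: Int => (String, String),      // docID -> (messageID, userID)
      fetchEncrypted: String => Option[String], // messageID -> ciphertext
      decrypt: (String, String) => String       // (userID, ciphertext) -> plaintext
  ): Seq[String] =
    docIDs.slice(start, start + rows).flatMap { docID =>
      val (messageID, userID) = storedKeys(docID)
      fetchEncrypted(messageID).map(ct => decrypt(userID, ct))
    }
}
```

Only the requested page slice is decrypted, so the cost per query is bounded by the page size rather than the number of hits.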

For encryption, I am using a symmetric block cipher (AES). For each user we generate a random 16-byte key and a corresponding initialization vector, which are henceforth used to encrypt documents belonging to that user. Both are stored in MongoDB. When users search their documents, their email address is passed in as a filter. The component uses it to look up the user's AES key and initialization vector and decrypt the documents for display in the Solr response.
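You can sketch the per-user key setup with the JDK's JCE classes alone. The post only says AES, so the "AES/CBC/PKCS5Padding" transformation, the UTF-8 encoding and the helper names below are assumptions; the actual CryptUtils class may differ.

```scala
import java.security.SecureRandom
import javax.crypto.Cipher
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

object AesSketch {
  private val rng = new SecureRandom()

  /** A random 16-byte (128-bit) AES key and a 16-byte initialization vector. */
  def newKeyAndIv(): (Array[Byte], Array[Byte]) = {
    val key = new Array[Byte](16)
    val iv = new Array[Byte](16)
    rng.nextBytes(key)
    rng.nextBytes(iv)
    (key, iv)
  }

  private def cipher(mode: Int, key: Array[Byte], iv: Array[Byte]): Cipher = {
    val c = Cipher.getInstance("AES/CBC/PKCS5Padding")
    c.init(mode, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
    c
  }

  def encrypt(plain: String, key: Array[Byte], iv: Array[Byte]): Array[Byte] =
    cipher(Cipher.ENCRYPT_MODE, key, iv).doFinal(plain.getBytes("UTF-8"))

  def decrypt(data: Array[Byte], key: Array[Byte], iv: Array[Byte]): String =
    new String(cipher(Cipher.DECRYPT_MODE, key, iv).doFinal(data), "UTF-8")
}
```

One design note: with CBC, reusing a single IV for all of a user's documents means that documents starting with the same plaintext blocks also start with the same ciphertext blocks. Generating a random IV per document and storing it alongside the ciphertext avoids this.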

An extension (which I haven't done) is for users to "allow" access to their records to someone else. I thought of using asymmetric key encryption (RSA) for encrypting the user's AES keys, so one could do a "circle of trust"-like model - however, since the keyring is centralized in MongoDB, it just seemed like additional complexity without any benefit, so I passed on that. Arguably, of course, this whole thing could have been avoided altogether by putting the authentication in a system outside Solr, and depending on your requirements that may well be a better approach. But I was trying to design an answer to the question :-).

So anyway, on to the code. The code is in Scala, in keeping with my plan to replace Java with Scala in my personal projects whenever possible, and as I get better at it, my Scala code is starting to look more like Scala than like Java. I realize that the main audience for my posts on stuff such as Solr customization is Java programmers, so I apologize if the code is somewhat hard (although not impossible if you try a bit) to read.

Configuration


I have a single configuration file that provides the location of the MongoDB host and the Solr server to the Indexing code. I reuse the same file (for convenience) to configure the custom Solr search component later. It looks like this:

# conf/secure/secure.properties
# Configuration file for EncryptingIndexer
num.workers=2
data.root.dir=/Users/sujit/Downloads/enron_mail_20110402/maildir
mongo.host=localhost
mongo.port=27017
mongo.db=solr4secure
solr.server=http://localhost:8983/solr/collection1/
default.fl=message_id,from,to,cc,bcc,date,subject,body
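The indexer below calls a properties(...) helper whose implementation is not shown in this section. A plausible sketch using java.util.Properties follows; the helper's name matches the call in EncryptingIndexer, but its body and its Map[String, String] return type are assumptions.

```scala
import java.io.{File, FileInputStream, InputStreamReader, Reader}
import java.util.Properties

object ConfigSketch {
  /** Loads key=value pairs (lines starting with # are comments) into an immutable Map. */
  def load(reader: Reader): Map[String, String] = {
    val props = new Properties()
    try props.load(reader) finally reader.close()
    val names = props.stringPropertyNames().toArray(new Array[String](0))
    names.map(n => n -> props.getProperty(n)).toMap
  }

  /** Hypothetical equivalent of the properties(file) helper used by the indexer. */
  def properties(file: File): Map[String, String] =
    load(new InputStreamReader(new FileInputStream(file), "UTF-8"))
}
```

For example, ConfigSketch.properties(new File("conf/secure/secure.properties"))("mongo.port").toInt would yield 27017 for the file above.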

For testing, I am using Solr's start.jar and its example/solr directory. I modified the schema.xml (in solr/collection1/conf) to add my own fields; here is the relevant snippet.

...
  <fields>
    <field name="message_id" type="string" indexed="true" stored="true" 
        required="true"/>
    <field name="from" type="string" indexed="true" stored="true"/>
    <field name="to" type="string" indexed="true" stored="false" 
        multiValued="true"/>
    <field name="cc" type="string" indexed="true" stored="false" 
        multiValued="true"/>
    <field name="bcc" type="string" indexed="true" stored="false" 
        multiValued="true"/>
    <field name="date" type="date" indexed="true" stored="false"/>
    <field name="subject" type="text_general" indexed="true" stored="false"/>
    <field name="body" type="text_general" indexed="true" stored="false"/>
    <field name="_version_" type="long" indexed="true" stored="true"/>
  </fields>
  ...
  <uniqueKey>message_id</uniqueKey>
  ...

MongoDB is schema-free, so no specific configuration is needed for it. However, make sure to create unique indexes on emails.message_id and users.email once the indexing (described below) is done; otherwise your searches will be very slow.

Indexing


The EncryptingIndexer requires both the MongoDB daemon (mongod) and Solr (java -jar start.jar) to be running. Once they are, you can start the job using "sbt run". The design of the indexer is based on Akka actors, very similar to one I built earlier. The master actor distributes the job to two worker actors (corresponding to the two CPUs on my laptop). Each worker reads and parses an input email file, creates or retrieves the AES key for the author of the email, uses that key to encrypt the document and store it in MongoDB, and publishes the document to Solr. Because of the schema definition above, most of the document is written into unstored fields. Here is the code.

// Source: src/main/scala/com/mycompany/solr4extras/secure/EncryptingIndexer.scala
package com.mycompany.solr4extras.secure

import java.io.File
import java.text.SimpleDateFormat

import scala.Array.canBuildFrom
import scala.collection.immutable.Stream.consWrapper
import scala.collection.immutable.HashMap
import scala.io.Source

import org.apache.solr.client.solrj.impl.HttpSolrServer
import org.apache.solr.common.SolrInputDocument

import akka.actor.actorRef2Scala
import akka.actor.{Props, ActorSystem, ActorRef, Actor}
import akka.routing.RoundRobinRouter

object EncryptingIndexer extends App {

  val props = properties(new File("conf/secure/secure.properties"))
  val system = ActorSystem("Solr4ExtrasSecure")
  val reaper = system.actorOf(Props[IndexReaper], name="reaper")
  val master = system.actorOf(Props(new IndexMaster(props, reaper)), 
    name="master")
  master ! StartMsg
  
  ///////////////// actors and messages //////////////////
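  sealed trait AbstractMsg
  // parameterless messages are case objects (parameterless case
  // classes are deprecated and rejected by later Scala versions)
  case object StartMsg extends AbstractMsg
  case class IndexMsg(file: File) extends AbstractMsg
  case class IndexedMsg(status: Int) extends AbstractMsg
  case object StopMsg extends AbstractMsg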
  
  /**
   * The master actor starts up the workers and the reaper,
   * then populates the input queue for the IndexWorker actors.
   * It also handles counters to track the progress of the job
   * and once the work is done sends a message to the Reaper
   * to shut everything down.
   */
  class IndexMaster(props: Map[String,String], reaper: ActorRef)
      extends Actor {

    val mongoDao = new MongoDao(props("mongo.host"),
        props("mongo.port").toInt,
        props("mongo.db"))
    val solrServer = new HttpSolrServer(props("solr.server"))
    val numWorkers = props("num.workers").toInt
    val router = context.actorOf(
      Props(new IndexWorker(mongoDao, solrServer)).
      withRouter(RoundRobinRouter(numWorkers)))
    
    var nreqs = 0
    var nsuccs = 0
    var nfails = 0
    
    override def receive = {
      case StartMsg => {
        val files = walk(new File(props("data.root.dir"))).
          filter(x => x.isFile)
        for (file <- files) {
          println("adding " + file + " to worker queue")
          nreqs = nreqs + 1
          router ! IndexMsg(file)
        }
      }
      case IndexedMsg(status) => {
        if (status == 0) nsuccs = nsuccs + 1 else nfails = nfails + 1
        val processed = nsuccs + nfails
        if (processed % 100 == 0) {
          solrServer.commit
          println("Processed %d/%d (success=%d, failures=%d)".
            format(processed, nreqs, nsuccs, nfails))
        }
        if (nreqs == processed) {
          solrServer.commit
          println("Processed %d/%d (success=%d, failures=%d)".
            format(processed, nreqs, nsuccs, nfails))
          reaper ! StopMsg
          context.stop(self)
        }
      }
    }
  }
  
  /**
   * These actors do the work of parsing the input file, encrypting
   * the content and writing the encrypted data to MongoDB and the
   * unstored data to Solr.
   */
  class IndexWorker(mongoDao: MongoDao, solrServer: HttpSolrServer) 
      extends Actor {
    
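    override def receive = {
      case IndexMsg(file) => {
        try {
          // parse inside the try so that an unreadable file is reported
          // as a failure instead of crashing the worker, which would
          // leave the master waiting forever for the reply
          val source = Source.fromFile(file)
          val doc = try parse(source) finally source.close()
          mongoDao.saveEncryptedDoc(doc)
          addToSolr(doc, solrServer)
          sender ! IndexedMsg(0)
        } catch {
          case e: Exception => {
            e.printStackTrace
            sender ! IndexedMsg(-1)
          }
        }
      }
    }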
  }
  
  /**
   * The Reaper shuts down the system once everything is done.
   */
  class IndexReaper extends Actor {
    override def receive = {
      case StopMsg => {
        println("Shutting down Indexer")
        context.system.shutdown
      }
    }
  }
  
  ///////////////// global functions /////////////////////
  
  /**
   * Add the document, represented as a Map[String,Any] name-value
   * pairs to the Solr index. Note that the schema sets all these
   * values to tokenized+unstored, so all we have in the index is
   * the inverted index for these fields.
   * @param doc the Map[String,Any] set of field key-value pairs.
   * @param server a reference to the Solr server.
   */
  def addToSolr(doc: Map[String,Any], server: HttpSolrServer): Unit = {
    val solrdoc = new SolrInputDocument()
    doc.keys.map(key => doc(key) match {
      case value: String => 
        solrdoc.addField(normalize(key), value.asInstanceOf[String])
      case value: Array[String] => 
        value.asInstanceOf[Array[String]].
          map(v => solrdoc.addField(normalize(key), v)) 
    })
    server.add(solrdoc)
  }

  /**
   * Normalize keys so they can be used without escaping in
   * Solr and MongoDB.
   * @param key the un-normalized string.
   * @return the normalized key (lowercased and space and 
   *         hyphen replaced by underscore).
   */
  def normalize(key: String): String = 
    key.toLowerCase.replaceAll("[ -]", "_")
    
  /**
   * Parse the email file into a set of name value pairs.
   * @param source the Source object representing the file.
   * @return a Map of name value pairs.
   */
  def parse(source: Source): Map[String,Any] = {
    parse0(source.getLines(), HashMap[String,Any](), false).
      filter(x => x._2 != null)
  }
  
  private def parse0(lines: Iterator[String], 
      map: Map[String,Any], startBody: Boolean): 
      Map[String,Any] = {
    if (lines.isEmpty) map
    else {
      val head = lines.next()
      if (head.trim.length == 0) parse0(lines, map, true)
      else if (startBody) {
        val body = map.getOrElse("body", "") + "\n" + head
        parse0(lines, map + ("body" -> body), startBody)
      } else {
        val split = head.indexOf(':')
        if (split > 0) {
          val kv = (head.substring(0, split), head.substring(split + 1))
          val key = kv._1.map(c => if (c == '-') '_' else c).trim.toLowerCase
          val value = kv._1 match {
            case "Date" => 
              formatDate(kv._2.trim)
            case "Cc" | "Bcc" | "To" => 
              kv._2.split("""\s*,\s*""")
            case "Message-ID" | "From" | "Subject" | "body" => 
              kv._2.trim
            case _ => null
          }
          parse0(lines, map + (key -> value), startBody)
        } else parse0(lines, map, startBody)
      }
    }
  }
  
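  /**
   * Convert an RFC 2822 style date such as
   * "Mon, 14 May 2001 16:39:00 -0700 (PDT)" into the ISO 8601 UTC
   * form Solr expects, e.g. "2001-05-14T23:39:00Z".
   */
  def formatDate(date: String): String = {
    val parser = new SimpleDateFormat("EEE, d MMM yyyy HH:mm:ss Z",
      java.util.Locale.US)
    val formatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'",
      java.util.Locale.US)
    // the trailing 'Z' means UTC, so format in UTC
    formatter.setTimeZone(java.util.TimeZone.getTimeZone("UTC"))
    // parse() reads the numeric offset and ignores the trailing "(PDT)"
    formatter.format(parser.parse(date))
  }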

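  def properties(conf: File): Map[String,String] = {
    Source.fromFile(conf).getLines().toList.
      filter(line => (! (line.trim.isEmpty || line.startsWith("#")))).
      map(line => {
        // split on the first '=' only, so that values (URLs, for
        // example) may themselves contain '='
        val kv = line.split("=", 2)
        (kv(0).trim -> (if (kv.length > 1) kv(1).trim else ""))
      }).toMap
  }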

  def walk(root: File): Stream[File] = {
    if (root.isDirectory)
      root #:: root.listFiles.toStream.flatMap(walk(_))
    else root #:: Stream.empty
  }
}
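Stripped of Akka, the master's job is bookkeeping: hand each file to one of a fixed number of workers, count successes and failures as they come back, and finish once the two counts add up to the number of requests. The sketch below shows the same pattern with a java.util.concurrent fixed thread pool instead of actors. It illustrates the design and is not a drop-in replacement; the names MasterWorkerSketch and process are hypothetical.

```scala
import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical sketch: the master/worker bookkeeping of IndexMaster,
// using a fixed thread pool instead of Akka actors.
object MasterWorkerSketch {

  /** Runs `work` on each item using `numWorkers` threads and returns
   *  (successes, failures). A task that throws is counted as a failure. */
  def process[A](items: Seq[A], numWorkers: Int)(work: A => Unit): (Int, Int) = {
    val pool = Executors.newFixedThreadPool(numWorkers)
    val nsuccs = new AtomicInteger(0)
    val nfails = new AtomicInteger(0)
    items.foreach { item =>
      pool.execute(new Runnable {
        def run(): Unit =
          try {
            work(item)
            nsuccs.incrementAndGet()
          } catch {
            case _: Exception => nfails.incrementAndGet()
          }
      })
    }
    // stop accepting new tasks, then wait for the queued ones to finish
    pool.shutdown()
    pool.awaitTermination(1, TimeUnit.MINUTES)
    (nsuccs.get, nfails.get)
  }
}
```

In the actor version, counting failures matters even more: the master shuts down only when nsuccs + nfails equals nreqs, so a worker that dies without replying leaves it waiting.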

The indexer depends on the MongoDao class, which in turn depends on the CryptUtils object; the Solr component described below depends on both as well. Here is the code for the MongoDao class, which uses Casbah, the Scala driver for MongoDB.

// Source: src/main/scala/com/mycompany/solr4extras/secure/MongoDao.scala
package com.mycompany.solr4extras.secure

import scala.Array.canBuildFrom
import scala.collection.JavaConversions.asScalaSet

import org.apache.commons.codec.binary.Hex

import com.mongodb.casbah.Imports._
import com.mycompany.solr4extras.secure.CryptUtils._

class MongoDao(host: String, port: Int, db: String) {

  val conn = MongoConnection(host, port)
  val users = conn(db)("users")
  val emails = conn(db)("emails")
  
  /**
   * Called from the indexing subsystem. The index document, 
   * represented as a Map of name-value pairs, is sent to this
   * method to be encrypted and persisted to a MongoDB collection.
   * @param doc the index document to be saved.
   */
  def saveEncryptedDoc(doc: Map[String,Any]): Unit = {
    val email = doc.get("from") match {
      case Some(x) => {
        val keypair = getKeys(x.asInstanceOf[String])
        val builder = MongoDBObject.newBuilder
        // save the message_id unencrypted since we will
        // need to look up using this
        builder += "message_id" -> doc("message_id")
        doc.keySet.filter(fn => (! fn.equals("message_id"))).
          map(fn => doc(fn) match {
          case value: String => {
            val dbval = Hex.encodeHexString(encrypt(
              value.asInstanceOf[String].getBytes, 
              keypair._1, keypair._2))
            builder += fn -> dbval
          }
          case value: Array[String] => { 
            val dbval = value.asInstanceOf[Array[String]].map(x => 
              Hex.encodeHexString(encrypt(
              x.getBytes, keypair._1, keypair._2)))
            builder += fn -> dbval
          }
        })
        emails.save(builder.result)
      }
      case None => 
        throw new Exception("Invalid Email, no sender, skip")
    } 
  }
  
  /**
   * Returns the AES key and init vector for a user. The users
   * collection in MongoDB is checked first; if there is no entry
   * for the email address, a new key pair is generated, saved to
   * the database and returned.
   * @param email the email address of the user.
   * @return pair of (key, initvector) for the user.
   */
  def getKeys(email: String): (Array[Byte], Array[Byte]) = {
    this.synchronized {
      val query = MongoDBObject("email" -> email)
      users.findOne(query) match {
        case Some(x) => {
          val keys = (Hex.decodeHex(x.as[String]("key").toCharArray), 
            Hex.decodeHex(x.as[String]("initvector").toCharArray))
          keys
        }
        case None => {
          val keys = CryptUtils.keys
          users.save(MongoDBObject(
            "email" -> email,
            "key" -> Hex.encodeHexString(keys._1),
            "initvector" -> Hex.encodeHexString(keys._2)
          ))
          keys
        }
      }
    }
  }
  
  /**
   * Called from the Solr DecryptComponent with list of docIds.
   * Retrieves the document corresponding to each id in the list
   * from MongoDB and returns it as a List of Maps, where each
   * document is represented as a Map of name and decrypted value 
   * pairs.
   * @param email the email address of the user, used to retrieve 
   *              the encryption key and init vector for the user.
   * @param fields the list of field names to return.
   * @param ids the list of docIds to return.
   * @return a List of Map[String,Any] documents.              
   */
  def getDecryptedDocs(email: String, fields: List[String], 
      ids: List[String]): List[Map[String,Any]] = {
    val (key, iv) = getKeys(email)
    val fl = MongoDBObject(fields.map(x => x -> 1))
    val cursor = emails.find("message_id" $in ids, fl)
    cursor.map(doc => getDecryptedDoc(doc, key, iv)).toList.
      sortWith((x, y) => 
        ids.indexOf(x("message_id")) < ids.indexOf(y("message_id")))
  }
  
  /**
   * Returns a document returned from MongoDB (as a DBObject)
   * decrypts it with the key and init vector, and returns the
   * decrypted object as a Map of name-value pairs.
   * @param doc the DBObject representing a single document.
   * @param key the byte array representing the AES key.
   * @param iv the init vector created at key creation.
   * @return a Map[String,Any] of name-value pairs, where values
   *         are decrypted.
   */
  def getDecryptedDoc(doc: DBObject, 
      key: Array[Byte], iv: Array[Byte]): Map[String,Any] = {
    val fieldnames = doc.keySet.toList.filter(fn => 
      (! "message_id".equals(fn)))
    val fieldvalues = fieldnames.map(fn => doc(fn) match {
      case value: String =>
        decrypt(Hex.decodeHex(value.asInstanceOf[String].toCharArray), 
          key, iv)
      case value: BasicDBList =>
        value.asInstanceOf[BasicDBList].elements.toList.
          map(v => decrypt(Hex.decodeHex(v.asInstanceOf[String].toCharArray), 
          key, iv))
      case _ =>
        doc(fn).toString
    })
    Map("message_id" -> doc("message_id")) ++ 
      fieldnames.zip(fieldvalues)
  }
}
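getDecryptedDocs sorts the MongoDB results because a $in query does not return documents in the order of the ids it was given, while the caller wants them in Solr's ranking order. A minimal sketch of that reordering follows, assuming each document exposes its message_id through a function. Precomputing an id-to-rank map avoids calling List.indexOf twice per comparison. RankOrder and restoreOrder are hypothetical names.

```scala
// Hypothetical helper: put documents back into the order of `ids`.
object RankOrder {

  def restoreOrder[D](ids: List[String], docs: Seq[D])(idOf: D => String): List[D] = {
    // message_id -> position in the Solr result list
    val rank = ids.zipWithIndex.toMap
    // unknown ids (should not happen) sort to the end
    docs.toList.sortBy(d => rank.getOrElse(idOf(d), Int.MaxValue))
  }
}
```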

Finally, here is the CryptUtils object, which exposes methods to generate AES keys and to encrypt and decrypt byte arrays. I didn't know much about the Java Cryptography Extension (JCE) before this (and I still don't know it in depth), but I found this multi-page tutorial on JavaMex and this post on CodeCrack very helpful in moving me along.

// Source: src/main/scala/com/mycompany/solr4extras/secure/CryptUtils.scala
package com.mycompany.solr4extras.secure

import java.security.SecureRandom

import javax.crypto.spec.{SecretKeySpec, IvParameterSpec}
import javax.crypto.Cipher

/**
 * Methods for generating random symmetric encryption keys,
 * and encrypting and decrypting text using these keys.
 */
object CryptUtils {

  def keys(): (Array[Byte], Array[Byte]) = {
    val rand = new SecureRandom
    val key = new Array[Byte](16)
    val iv = new Array[Byte](16)
    rand.nextBytes(key)
    rand.nextBytes(iv)
    (key, iv)
  }

  def encrypt(data: Array[Byte], 
      key: Array[Byte], iv: Array[Byte]): Array[Byte] = {
    val keyspec = new SecretKeySpec(key, "AES")
    val cipher = Cipher.getInstance("AES/CBC/PKCS5PADDING")
    if (iv == null) 
      cipher.init(Cipher.ENCRYPT_MODE, keyspec)
    else
      cipher.init(Cipher.ENCRYPT_MODE, keyspec, 
        new IvParameterSpec(iv))
    cipher.doFinal(data)
  }
  
  def decrypt(encdata: Array[Byte], key: Array[Byte],
      initvector: Array[Byte]): String = {
    val keyspec = new SecretKeySpec(key, "AES")
    val cipher = Cipher.getInstance("AES/CBC/PKCS5PADDING")
    cipher.init(Cipher.DECRYPT_MODE, keyspec, 
      new IvParameterSpec(initvector))
    val decrypted = cipher.doFinal(encdata)
    new String(decrypted, 0, decrypted.length)
  }
}
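To see the encryption in isolation, here is a self-contained round trip through the same JCE transformation CryptUtils uses (AES/CBC/PKCS5Padding with a 16-byte key and IV, i.e. AES-128), followed by a hex encoding like the one MongoDao uses to store the bytes as strings. It is a sketch, not the project's code. AesRoundTrip and its methods are hypothetical; the hex helpers are hand-rolled to stay within the JDK instead of using commons-codec; and unlike the original, the sketch pins the charset to UTF-8 instead of relying on the platform default.

```scala
import java.security.SecureRandom
import javax.crypto.Cipher
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

// Hypothetical sketch of an AES/CBC round trip plus hex encoding.
object AesRoundTrip {

  def randomBytes(n: Int): Array[Byte] = {
    val bytes = new Array[Byte](n)
    new SecureRandom().nextBytes(bytes)
    bytes
  }

  private def cipher(mode: Int, key: Array[Byte], iv: Array[Byte]): Cipher = {
    val c = Cipher.getInstance("AES/CBC/PKCS5Padding")
    c.init(mode, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
    c
  }

  def encrypt(text: String, key: Array[Byte], iv: Array[Byte]): Array[Byte] =
    cipher(Cipher.ENCRYPT_MODE, key, iv).doFinal(text.getBytes("UTF-8"))

  def decrypt(data: Array[Byte], key: Array[Byte], iv: Array[Byte]): String =
    new String(cipher(Cipher.DECRYPT_MODE, key, iv).doFinal(data), "UTF-8")

  // lowercase hex, two characters per byte (negative bytes print as 80..ff)
  def toHex(bytes: Array[Byte]): String =
    bytes.map(b => "%02x".format(b)).mkString

  def fromHex(hex: String): Array[Byte] =
    hex.grouped(2).map(h => Integer.parseInt(h, 16).toByte).toArray
}
```

Pinning the charset means a document encrypted on one machine decrypts to the same text on another, whatever each JVM's default encoding is.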

At the end of the run (it takes a while to finish), you should have about 600,000 searchable records in Solr, and two collections (users and emails) in the MongoDB database solr4secure, containing the encryption keys and the encrypted documents respectively.

Search


As I mentioned above, the main work on the search (Solr) side is done by a search component that is tacked on to the end of the standard request handler chain to create a new service. This is configured in Solr's solrconfig.xml file (under solr/collection1/conf) like so:

<searchComponent class="com.mycompany.solr4extras.secure.DecryptComponent" 
      name="decrypt">
    <str name="config-file">secure.properties</str>
  </searchComponent>

  <requestHandler name="/secure_select" class="solr.SearchHandler">
     <lst name="defaults">
       <str name="echoParams">explicit</str>
       <int name="rows">10</int>
       <str name="fl">*</str>
     </lst>
     <arr name="last-components">
       <str>decrypt</str>
     </arr>
  </requestHandler>

The code for my custom DecryptComponent is shown below. It is SolrCoreAware, so that it can read its configuration file (secure.properties, which is dropped into the solr/collection1/conf directory) and create a connection to MongoDB. In the process method, it reads the response from the previous component and looks up the message_ids using the docIDs. It then fetches the encrypted documents from MongoDB using the message_ids, decrypts them with the key of the user identified by the email request parameter, and replaces the Solr response with its own.

// Source: src/main/scala/com/mycompany/solr4extras/secure/DecryptComponent.scala
package com.mycompany.solr4extras.secure

import java.io.{FileInputStream, File}
import java.util.Properties

import scala.collection.JavaConversions.seqAsJavaList

import org.apache.solr.common.{SolrDocumentList, SolrDocument}
import org.apache.solr.core.SolrCore
import org.apache.solr.handler.component.{SearchComponent, ResponseBuilder}
import org.apache.solr.response.ResultContext
import org.apache.solr.util.plugin.SolrCoreAware

class DecryptComponent extends SearchComponent with SolrCoreAware {

  var mongoDao: MongoDao = null
  var defaultFieldList: List[String] = null
  
  def getDescription(): String = 
    "Decrypts record with AES key for user identified by email"

  def getSource(): String = "DecryptComponent.scala"

  override def inform(core: SolrCore): Unit = {
    val props = new Properties()
    props.load(new FileInputStream(new File(
      core.getResourceLoader.getConfigDir, "secure.properties")))
    val host = props.get("mongo.host").asInstanceOf[String]
    val port = Integer.valueOf(props.get("mongo.port").asInstanceOf[String])
    val db = props.get("mongo.db").asInstanceOf[String]
    mongoDao = new MongoDao(host, port, db)
    defaultFieldList = props.get("default.fl").asInstanceOf[String].
      split(",").toList
  }
  
  override def prepare(rb: ResponseBuilder): Unit = { /* NOOP */ }

  override def process(rb: ResponseBuilder): Unit = {
    val params = rb.req.getParams
    // SolrParams.get returns null for a missing parameter
    val fl = params.get("fl")
    val dfl = if (fl == null || fl.isEmpty || fl == "*") 
      defaultFieldList
      else fl.split(",").toList
    val email = params.get("email")
    if (email != null && ! email.isEmpty) {
      // get docIds returned by previous component
      val nl = rb.rsp.getValues
      val ictx = nl.get("response").asInstanceOf[ResultContext]
      var docids = List[Integer]()
      val dociter = ictx.docs.iterator
      while (dociter.hasNext) docids = dociter.nextDoc :: docids
      // extract message_ids from the index and populate list
      val searcher = rb.req.getSearcher
      val mfl = new java.util.HashSet[String](List("message_id"))
      val messageIds = docids.reverse.map(docid => 
        searcher.doc(docid, mfl).get("message_id"))
      // populate a SolrDocumentList from index
      val solrdoclist = new SolrDocumentList
      solrdoclist.setMaxScore(ictx.docs.maxScore)
      solrdoclist.setNumFound(ictx.docs.matches)
      solrdoclist.setStart(ictx.docs.offset)
      val docs = mongoDao.getDecryptedDocs(email, dfl, messageIds).
        map(fieldmap => {
          val doc = new SolrDocument()
          fieldmap.keys.toList.map(fn => fieldmap(fn) match {
              case value: String =>
                doc.addField(fn, value.asInstanceOf[String])
              case value: List[String] => 
                value.asInstanceOf[List[String]].map(v =>
                  doc.addField(fn, v))
          })
          doc
      })
      solrdoclist.addAll(docs)
      // swap the response with the generated one
      rb.rsp.getValues().remove("response")
      rb.rsp.add("response", solrdoclist)
    }
  }
}
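The component reads two request parameters: fl, which falls back to the default.fl list from secure.properties when it is "*", and email, which selects the user's key. SolrParams.get returns null when a parameter is absent, so a sketch of null-safe handling for both might look like this. DecryptParams, fieldList and shouldDecrypt are hypothetical helpers, not part of the component.

```scala
// Hypothetical helpers for the fl and email request parameters.
object DecryptParams {

  /** Missing, blank or "*" means "use the configured defaults". */
  def fieldList(fl: String, defaults: List[String]): List[String] =
    Option(fl).map(_.trim) match {
      case None | Some("") | Some("*") => defaults
      case Some(value) => value.split(",").map(_.trim).filter(_.nonEmpty).toList
    }

  /** Only decrypt when the request names a user. */
  def shouldDecrypt(email: String): Boolean =
    Option(email).exists(_.trim.nonEmpty)
}
```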

Custom JARs can be "plugged in" to Solr by dropping them into Solr's lib directory (in my example, solr/collection1/lib). You can build the JAR for this project by running "sbt package".

Since my code is in Scala, I also needed to add the Scala runtime (scala-library.jar) and scalaj-collection, as well as the JARs my code depends on (Casbah and the MongoDB Java driver) and the project JAR itself. The full list of additional JARs (an ls of the lib directory) is shown below:

casbah-commons_2.9.2-2.3.0.jar  mongo-java-driver-2.8.0.jar
casbah-core_2.9.2-2.3.0.jar     scala-library.jar
casbah-query_2.9.2-2.3.0.jar    scalaj-collection_2.9.1-1.2.jar
casbah-util_2.9.2-2.3.0.jar     solr4-extras_2.9.2-1.0.jar

Having done all this, it's time to test the new service. The Solr server should come up cleanly on restart. Then a URL like this:

http://localhost:8983/solr/collection1/secure_select\
    ?q=body:%22hedge%20fund%22\
    &fq=from:kaye.ellis@enron.com\
    &email=kaye.ellis@enron.com

should yield a response that looks like this:


That's all I have for this week. It was interesting for me because this is the first time I have looked at Solr4 (which looks quite impressive; congratulations to the Solr team and thanks for your hard work making this happen), the first time I wrote a custom Solr component in Scala, and also the first time I used MongoDB outside of the exercises for the M101 course. Hope you found it interesting as well.

For those interested, the full source and configuration files for this project are available on my solr4-extras project on GitHub.