Saturday, December 26, 2009

A Lucene POS Tagging TokenFilter

Of late, I've been busy converting a fairly large application from Lucene 2.4 to Lucene 2.9. There are quite a few improvements - see this white paper from Lucid Imagination for a quick overview (sign-up required). Along with the many performance improvements, there are some API changes and deprecations. Some of these deprecations came with a fair amount of warning, such as the Hits object, which is finally going away in Lucene 3.0, but others are scheduled more aggressively - there is a brand new API for TokenStream in Lucene 2.9, and the old API will be removed in Lucene 3.0.

The new API is actually quite nice - it's more flexible, since you can add custom properties to a Token during analysis simply by adding the appropriate TokenFilter. The filter stores user-defined Attributes for each Token, which can then be read back downstream. Prior to this Attribute-based API, you would probably have used the Token's payload to do the same thing. In any case, the new API is described in detail in the Javadocs. I thought it would be interesting to actually build a TokenFilter against the new API, so I recycled some old ideas to create a Part of Speech tagging TokenFilter.
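For orientation, here is a minimal sketch (my own, not lifted from the Javadocs) of the consumption pattern using two of the built-in attributes; the casts are needed because addAttribute() is not generified in Lucene 2.9.

// Minimal sketch (not part of the project) of the attribute-based consumption
// pattern in Lucene 2.9.
import java.io.StringReader;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.WhitespaceAnalyzer;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.TermAttribute;

public class NewTokenStreamApiSketch {
  public static void main(String[] args) throws Exception {
    TokenStream stream = new WhitespaceAnalyzer().tokenStream(
      "f", new StringReader("the quick brown fox"));
    // attributes are requested once up front, then re-read after each increment
    TermAttribute term = (TermAttribute) stream.addAttribute(TermAttribute.class);
    OffsetAttribute offset =
      (OffsetAttribute) stream.addAttribute(OffsetAttribute.class);
    while (stream.incrementToken()) {
      System.out.println(term.term() +
        " [" + offset.startOffset() + "," + offset.endOffset() + "]");
    }
    stream.end();
    stream.close();
  }
}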

Custom Attribute - Interface and Implementation

The first step is to create a custom attribute to hold the information about the Part of Speech. The interface simply defines getters and setters for the Pos object.

// Source: src/main/java/net/sf/jtmt/postaggers/lucene/PosAttribute.java
package net.sf.jtmt.postaggers.lucene;

import net.sf.jtmt.postaggers.Pos;

import org.apache.lucene.util.Attribute;

/**
 * Part of Speech Attribute.
 */
public interface PosAttribute extends Attribute {
  public void setPos(Pos pos);
  public Pos getPos();
}

The implementation class is named by suffixing the interface name with Impl, a convention enforced by the AttributeSource class, and it must also extend AttributeImpl.
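Just to illustrate the convention (this snippet is not part of the project code), you never instantiate the Impl class yourself - AttributeSource's default factory appends "Impl" to the interface name and instantiates PosAttributeImpl by reflection when you ask for the interface:

// Illustration only: AttributeSource resolves PosAttributeImpl by reflection.
import net.sf.jtmt.postaggers.Pos;
import net.sf.jtmt.postaggers.lucene.PosAttribute;

import org.apache.lucene.util.AttributeSource;

public class PosAttributeNamingSketch {
  public static void main(String[] args) {
    AttributeSource source = new AttributeSource();
    PosAttribute posAttr = (PosAttribute) source.addAttribute(PosAttribute.class);
    // prints net.sf.jtmt.postaggers.lucene.PosAttributeImpl
    System.out.println(posAttr.getClass().getName());
    posAttr.setPos(Pos.NOUN);
    System.out.println(posAttr.getPos());
  }
}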

// Source: src/main/java/net/sf/jtmt/postaggers/lucene/PosAttributeImpl.java
package net.sf.jtmt.postaggers.lucene;

import org.apache.lucene.util.AttributeImpl;

import net.sf.jtmt.postaggers.Pos;

public class PosAttributeImpl extends AttributeImpl implements PosAttribute {

  private static final long serialVersionUID = -8416041956010464591L;

  private Pos pos = Pos.OTHER;
  
  public Pos getPos() {
    return pos;
  }

  public void setPos(Pos pos) {
    this.pos = pos;
  }

  @Override
  public void clear() {
    this.pos = Pos.OTHER;
  }

  @Override
  public void copyTo(AttributeImpl target) {
    ((PosAttributeImpl) target).setPos(pos);
  }

  @Override
  public boolean equals(Object other) {
    if (other == this) {
      return true;
    }
    if (other instanceof PosAttributeImpl) {
      return pos == ((PosAttributeImpl) other).getPos();
    }
    return false;
  }

  @Override
  public int hashCode() {
    return pos.ordinal();
  }
}

TokenFilter

The Part of Speech TokenFilter is the meat of the implementation. The constructor takes the parameters needed to set it up, and the incrementToken() method implements the actual POS tagging logic. As I said before, I am just recycling old ideas here, so refer to my old post if you have trouble following the code.

Essentially, we open two TokenStreams over the same text, one for the current token and one for the next token (since we cannot peek ahead within a single stream, we need two). For each current token, we query Wordnet for the POS. If a single POS is returned, we just use it. If no POS is returned, we use suffix rules to guess the POS. If multiple POS values are returned, we use a table of inter-POS transition probabilities to determine the "most likely" POS for the current term: each candidate is scored by adding the transition probability from the previous token's POS to it, plus the probabilities from it to each of the next token's candidate POS values, and the highest-scoring candidate wins.
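To make the disambiguation step concrete before the full listing, here is a small standalone sketch of the scoring rule; the Pos enum here is a stand-in for the real one, and the probability values are invented for illustration (the real filter reads them from a table derived from the Brown corpus).

// Standalone sketch of the "most likely POS" scoring rule, with made-up numbers.
import java.util.EnumMap;
import java.util.Map;

public class PosScoringSketch {

  // stand-in for the real Pos enum
  enum Pos { NOUN, VERB, OTHER }

  public static void main(String[] args) {
    // invented transition probabilities P(row -> col)
    Map<Pos,Map<Pos,Double>> p = new EnumMap<Pos,Map<Pos,Double>>(Pos.class);
    for (Pos a : Pos.values()) {
      p.put(a, new EnumMap<Pos,Double>(Pos.class));
    }
    p.get(Pos.OTHER).put(Pos.NOUN, 0.30);
    p.get(Pos.OTHER).put(Pos.VERB, 0.10);
    p.get(Pos.NOUN).put(Pos.OTHER, 0.25);
    p.get(Pos.VERB).put(Pos.OTHER, 0.20);

    Pos prev = Pos.OTHER;                       // e.g. "the"
    Pos[] candidates = { Pos.NOUN, Pos.VERB };  // Wordnet says "state" can be either
    Pos next = Pos.OTHER;                       // e.g. "of"

    // score = P(prev -> candidate) + P(candidate -> next), highest wins
    Pos best = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (Pos cand : candidates) {
      double score = p.get(prev).get(cand) + p.get(cand).get(next);
      System.out.println(cand + " scores " + score);
      if (score > bestScore) {
        bestScore = score;
        best = cand;
      }
    }
    System.out.println("Most likely POS: " + best); // NOUN (0.55 vs 0.30)
  }
}

The real filter does the same thing in getMostProbablePos(), except that it sums over all of the next word's candidate POS values rather than a single one.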

// Source: src/main/java/net/sf/jtmt/postaggers/lucene/PosTaggingFilter.java
package net.sf.jtmt.postaggers.lucene;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import net.sf.jtmt.clustering.ByValueComparator;
import net.sf.jtmt.postaggers.Pos;

import org.apache.commons.collections15.keyvalue.MultiKey;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.TermAttribute;

import edu.mit.jwi.Dictionary;
import edu.mit.jwi.IDictionary;
import edu.mit.jwi.item.IIndexWord;

/**
 * Sets the POS tag for each token in the token stream. Uses Wordnet and some
 * grammar rules to make an initial determination. If Wordnet returns multiple
 * POS possibilities, we use the context surrounding the word (previous and
 * next characters) and a table of pre-calculated probabilities (from the
 * Brown corpus) to determine the most likely POS. If Wordnet returns a single 
 * POS, that is accepted. If Wordnet cannot determine the POS, then word 
 * suffix rules are used to guess the POS.
 */
public class PosTaggingFilter extends TokenFilter {

  private String prevTerm = null;
  private String nextTerm = null;

  private Pos prevPos = null;

  private TokenStream suffixStream;
  private IDictionary wordnetDictionary;
  private Map<String,Pos> suffixPosMap;
  private Map<MultiKey<Pos>,Double> transitionProbabilities;
  
  private PosAttribute posAttr;
  private TermAttribute termAttr;
  private TermAttribute suffixTermAttr;
  
  protected PosTaggingFilter(TokenStream input, 
      TokenStream suffixStream,
      String wordnetDictionaryPath, String suffixRulesPath, 
      String posTransitionProbabilitiesPath) throws Exception {
    super(input);
    // declare the POS attribute for both streams
    this.posAttr = (PosAttribute) addAttribute(PosAttribute.class);
    this.termAttr = (TermAttribute) addAttribute(TermAttribute.class);
    this.suffixTermAttr = 
      (TermAttribute) suffixStream.addAttribute(TermAttribute.class);
    this.prevTerm = null;
    this.suffixStream = suffixStream;
    this.input.reset();
    this.suffixStream.reset();
    // advance the pointer to the next token for the suffix stream
    if (this.suffixStream.incrementToken()) {
      this.nextTerm = suffixTermAttr.term();
    }
    // create artifacts for doing the POS tagging
    this.wordnetDictionary = initWordnetDictionary(wordnetDictionaryPath);
    this.suffixPosMap = initSuffixPosMap(suffixRulesPath); 
    this.transitionProbabilities = initTransitionProbabilities(
      posTransitionProbabilitiesPath);
  }

  public final boolean incrementToken() throws IOException {
    String currentTerm = null;
    if (input.incrementToken()) {
      currentTerm = termAttr.term();
    } else {
      return false;
    }
    if (suffixStream.incrementToken()) {
      this.nextTerm = suffixTermAttr.term();
    } else {
      this.nextTerm = null;
    }
    termAttr.setTermBuffer(currentTerm);
    // find the POS of the current word from Wordnet
    List<Pos> currentPoss = getPosFromWordnet(currentTerm);
    if (currentPoss.size() == 1) {
      // unambiguous match, look no further
      posAttr.setPos(currentPoss.get(0));
    } else if (currentPoss.size() == 0) {
      // wordnet could not find a POS, fall back to heuristics
      Pos pos = null;
      if (prevTerm != null) {
        // this is not the first word; a capitalized word mid-sentence
        // is likely a (proper) Noun
        if (Character.isUpperCase(currentTerm.charAt(0))) {
          pos = Pos.NOUN;
        }
      }
      if (pos == null) {
        // otherwise guess the POS from the word suffix
        pos = getPosFromSuffixRules(currentTerm);
      }
      posAttr.setPos(pos);
    } else {
      // wordnet reported multiple POS, find the best one
      Pos pos = getMostProbablePos(currentPoss, nextTerm, prevPos);
      posAttr.setPos(pos);
    }
    this.prevTerm = currentTerm;
    this.prevPos = posAttr.getPos();
    return true;
  }
  
  private IDictionary initWordnetDictionary(String wordnetDictionaryPath) 
      throws Exception {
    IDictionary wordnetDictionary = new Dictionary(
      new URL("file", null, wordnetDictionaryPath));
    wordnetDictionary.open();
    return wordnetDictionary;
  }

  private Map<String, Pos> initSuffixPosMap(String suffixRulesPath) 
      throws Exception {
    Map<String,Pos> suffixPosMap = new TreeMap<String, Pos>(
      new Comparator<String>() {
        public int compare(String s1, String s2) {
          int l1 = s1 == null ? 0 : s1.length();
          int l2 = s2 == null ? 0 : s2.length();
          if (l1 == l2) {
            // tie-break on the suffix itself, otherwise distinct suffixes
            // of the same length would collide in the TreeMap
            return (s1 == null ? "" : s1).compareTo(s2 == null ? "" : s2);
          } else {
            return (l2 > l1 ? 1 : -1);
          }
        }
    });
    BufferedReader reader = new BufferedReader(new FileReader(suffixRulesPath));
    String line = null;
    while ((line = reader.readLine()) != null) {
      if (StringUtils.isEmpty(line) || line.startsWith("#")) {
        continue;
      }
      String[] suffixPosPair = StringUtils.split(line, "\t");
      suffixPosMap.put(suffixPosPair[0], Pos.valueOf(suffixPosPair[1]));
    }
    reader.close();
    return suffixPosMap;
  }

  private Map<MultiKey<Pos>,Double> initTransitionProbabilities(
      String transitionProbabilitiesPath) throws Exception {
    Map<MultiKey<Pos>,Double> transitionProbabilities = 
      new HashMap<MultiKey<Pos>,Double>();
    BufferedReader reader = new BufferedReader(
      new FileReader(transitionProbabilitiesPath));
    String line = null;
    int row = 0;
    while ((line = reader.readLine()) != null) {
      if (StringUtils.isEmpty(line) || line.startsWith("#")) {
        continue;
      }
      String[] cols = StringUtils.split(line, "\t");
      for (int col = 0; col < cols.length; col++) {
        MultiKey<Pos> key = new MultiKey<Pos>(Pos.values()[row], Pos.values()[col]);
        transitionProbabilities.put(key, Double.valueOf(cols[col]));
      }
      row++;
    }
    reader.close();
    return transitionProbabilities;
  }

  private List<Pos> getPosFromWordnet(String currentTerm) {
    List<Pos> poss = new ArrayList<Pos>();
    for (Pos pos : Pos.values()) {
      try {
        IIndexWord indexWord = wordnetDictionary.getIndexWord(
          currentTerm, Pos.toWordnetPos(pos));
        if (indexWord != null) {
          poss.add(pos);
        }
      } catch (NullPointerException e) {
        // JWI throws NPE if it cannot find a word in dictionary
        continue;
      }
    }
    return poss;
  }

  private Pos getPosFromSuffixRules(String currentTerm) {
    for (String suffix : suffixPosMap.keySet()) {
      if (StringUtils.lowerCase(currentTerm).endsWith(suffix)) {
        return suffixPosMap.get(suffix);
      }
    }
    return Pos.OTHER;
  }

  private Pos getMostProbablePos(List<Pos> currentPoss, String nextTerm,
      Pos prevPos) {
    Map<Pos,Double> posProbs = new HashMap<Pos,Double>();
    // find the possible POS values for the previous and current term
    if (prevPos != null) {
      for (Pos currentPos : currentPoss) {
        MultiKey<Pos> key = new MultiKey<Pos>(prevPos, currentPos);
        double prob = transitionProbabilities.get(key);
        if (posProbs.containsKey(currentPos)) {
          posProbs.put(currentPos, posProbs.get(currentPos) + prob);
        } else {
          posProbs.put(currentPos, prob);
        }
      }
    }
    // find the possible POS values for the current and previous term
    if (nextTerm != null) {
      List<Pos> nextPoss = getPosFromWordnet(nextTerm);
      if (nextPoss.size() == 0) {
        nextPoss.add(Pos.OTHER);
      }
      for (Pos currentPos : currentPoss) {
        for (Pos nextPos : nextPoss) {
          MultiKey<Pos> key = new MultiKey<Pos>(currentPos, nextPos);
          double prob = transitionProbabilities.get(key);
          if (posProbs.containsKey(currentPos)) {
            posProbs.put(currentPos, posProbs.get(currentPos) + prob);
          } else {
            posProbs.put(currentPos, prob);
          }
        }
      }
    }
    // now find the current Pos with the maximum probability
    if (posProbs.size() == 0) {
      return Pos.OTHER;
    } else {
      ByValueComparator<Pos,Double> bvc = 
        new ByValueComparator<Pos,Double>(posProbs);
      List<Pos> posList = new ArrayList<Pos>();
      posList.addAll(posProbs.keySet());
      Collections.sort(posList, Collections.reverseOrder(bvc));
      return posList.get(0);
    }
  }
}

Analyzer

For the Analyzer, I used a bare-bones Analyzer built around a StandardTokenizer and tacked the POS TokenFilter onto its output. So as each Token comes out with its term, it also carries the PosAttribute, which the caller can read if desired. The code is written as a JUnit test, since I was really testing the TokenFilter.

// Source: src/test/java/net/sf/jtmt/postaggers/lucene/PosTaggingFilterTest.java
package net.sf.jtmt.postaggers.lucene;

import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.lang.StringUtils;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardTokenizer;
import org.apache.lucene.analysis.tokenattributes.TermAttribute;
import org.apache.lucene.util.Version;
import org.junit.Test;

/**
 * Test for the Lucene POS Tagging Filter.
 */
public class PosTaggingFilterTest {

  private String[] INPUT_TEXTS = {
    "The growing popularity of Linux in Asia, Europe, and the US is " +
    "a major concern for Microsoft.",
    "Jaguar will sell its new XJ-6 model in the US for a small fortune.",
    "The union is in a sad state.",
    "Please do not state the obvious.",
    "I am looking forward to the state of the union address.",
    "I have a bad cold today.",
    "The cold war was over long ago."
  };
  
  private Analyzer analyzer = new Analyzer() {
    @Override
    public TokenStream tokenStream(String fieldName, Reader reader) {
      return new StandardTokenizer(Version.LUCENE_CURRENT, reader);
    }
  };
  
  @Test
  public void testPosTagging() throws Exception {
    for (String inputText : INPUT_TEXTS) {
      System.out.println("Input: " + inputText);
      List<String> tags = new ArrayList<String>();
      TokenStream input = analyzer.tokenStream(
        "f", new StringReader(inputText));
      TokenStream suffixStream = 
        analyzer.tokenStream("f", new StringReader(inputText));
      input = new PosTaggingFilter(input, suffixStream,
        "/opt/wordnet-3.0/dict", "src/main/resources/pos_suffixes.txt",
        "src/main/resources/pos_trans_prob.txt");
      TermAttribute termAttribute = 
        (TermAttribute) input.addAttribute(TermAttribute.class);
      PosAttribute posAttribute = 
        (PosAttribute) input.addAttribute(PosAttribute.class);
      while (input.incrementToken()) {
        tags.add(termAttribute.term() + "/" + posAttribute.getPos());
      }
      input.end();
      input.close();
      StringBuilder tagBuf = new StringBuilder();
      tagBuf.append("Tagged: ").append(StringUtils.join(tags.iterator(), " "));
      System.out.println(tagBuf.toString());
    }
  }
}

Results

The test inputs are identical to those in the old post, although the new code seems to have difficulty figuring out verbs. Here are the results (cleaned up a bit for display).

Input:  The growing popularity of Linux in Asia, Europe, and the US is 
        a major concern for Microsoft.
Tagged: The/OTHER growing/ADJECTIVE popularity/NOUN of/OTHER Linux/NOUN 
        in/ADJECTIVE Asia/NOUN Europe/NOUN and/OTHER the/OTHER US/NOUN 
        is/OTHER a/NOUN major/ADJECTIVE concern/NOUN for/OTHER 
        Microsoft/OTHER

Input:  Jaguar will sell its new XJ-6 model in the US for a small fortune.
Tagged: Jaguar/NOUN will/NOUN sell/NOUN its/OTHER new/ADVERB XJ-6/OTHER 
        model/ADJECTIVE in/NOUN the/OTHER US/NOUN for/OTHER a/NOUN 
        small/ADJECTIVE fortune/NOUN

Input:  The union is in a sad state.
Tagged: The/OTHER union/NOUN is/OTHER in/ADJECTIVE a/NOUN sad/ADJECTIVE 
        state/NOUN

Input:  Please do not state the obvious.
Tagged: Please/VERB do/VERB not/ADVERB state/NOUN the/OTHER obvious/ADJECTIVE

Input:  I am looking forward to the state of the union address.
Tagged: I/ADJECTIVE am/NOUN looking/ADJECTIVE forward/NOUN to/OTHER the/OTHER
         state/NOUN of/OTHER the/OTHER union/ADJECTIVE address/NOUN

Input:  I have a bad cold today.
Tagged: I/ADJECTIVE have/NOUN a/NOUN bad/ADJECTIVE cold/NOUN today/NOUN

Input:  The cold war was over long ago.
Tagged: The/OTHER cold/ADJECTIVE war/NOUN was/OTHER over/ADVERB long/VERB 
        ago/ADJECTIVE

Although the old TokenStream API is being deprecated on a short fuse, it's probably not such a big deal, since most people don't use the API directly. And from a migration point of view, even if you do use it directly, it's probably a lot less code than, say, your search or indexing code, so I guess it's all good.

Friday, December 04, 2009

A Unison replacement with rsync

Before Unison, I used a simple rsync script to synchronize code between my laptop and desktop. If you are interested, it is described here. The script was a simple Python wrapper over the Unix rsync command, just so I didn't have to remember all the switches.

However, the script was overly simplistic, and required some discipline to ensure that files did not get clobbered during syncing. For one, you had to start from a known "clean" state, so anytime you wanted to make a change on the laptop, you had to download the latest code from the desktop first. Once your changes were done, you had to remember to upload them back.

Having used Unison for a while now, I have gotten used to it telling me when I am about to shoot myself in the foot, rather than having to figure it out for myself. So it was something of a setback when I could not get Unison to work on my MacBook Pro (syncing against a CentOS 5.3 based desktop), and I could not face going back to the old script either. I decided to add some smarts to the old program so that it behaves more like Unison.

Challenges

Unison does a bidirectional sync each time it is called. One can simulate this (sort of) with a pair of rsync calls (an upsync and a downsync), each run with the --update switch so that newer files from each side are propagated to the other.

Relying on file timestamps has a few problems, though. First, we assume that the clocks on the two machines are close enough, an assumption that is probably mostly true since most modern machines run ntpd.

Second (and perhaps more importantly), there is a chance of a local change being clobbered if a newer version of the same file exists on the remote machine. This can happen in my case because the files on the remote machine (my desktop) are under CVS control: if someone checks in a change to a file that I synced and then modified locally, a "cvs update" on the remote machine before the next sync makes the remote copy newer, and the sync would overwrite the version of the file on my laptop.

There is also the reverse case where your local changes can propagate over a remote change that was previously committed, but doing a "cvs update" before a "cvs commit" should detect that, so I am not worried so much about handling that case.

Script

To handle the local file clobbering problem, in addition to simulating the bidirectional sync with a pair of rsync calls, I also build a snapshot of the files after each sync - the snapshot is really a pickled dictionary (serialized Map for you Java guys) of the MD5 checksum of each file after the sync. On the next sync call, I use the snapshot to find which files have changed locally. Then I do a downsync in --dry-run mode and remove from the downsync file list the files that have changed locally; this prevents files that have changed locally from being overwritten by remote changes. I then do an upsync in --dry-run mode, and remove from the conflict candidates those files that appear in the upsync list. The remaining files are essentially "conflicts" which the program does not know what to do with, and for which it defers to my decision (whether to upsync, downsync or ignore).

The user interface (i.e., the configuration files and console output) is heavily influenced by Unison's, since I wanted to reuse my profiles as much as possible. The configuration files are stored in a .sync_conf directory under the home directory, as named profile files containing key-value pairs.

A sample profile is shown below. It identifies the local and remote root directories for the profile, and specifies the file patterns that should be excluded from the sync.

local=/Users/sujit/test
remote=spal@localhost:/home/spal/test
excludes=.*,target/*,bin/*

If you look in the ~/.sync_conf directory, you will also find a .dat file for each profile after the first sync is done - this is the snapshot. If you delete the snapshot, make sure that you don't have any outstanding local changes (or make copies of them) before rerunning the sync.

As you can figure out from the spal@localhost prefix on the remote key value, I use a local tunnel on my laptop to connect to my desktop over ssh. Since each sync involves multiple rsync calls, I needed to set up passwordless ssh to avoid having to type the password multiple times.

Here is the code - like its previous incarnation, it is written in Python. The script is heavily documented, and I have already briefly described the algorithm above, so it should not be too hard to understand.

#!/usr/bin/python
# Does a somewhat brittle bi-directional sync. By brittle, I mean that
# this is tailored to my particular use-case rather than a general one.
# My use case is a laptop being sync'd to a desktop at work. The code
# on the desktop is under CVS control (so I can potentially recover
# clobbered files). The script tries to minimize the chance of clobbering
# files on the laptop.
#
import os
import sys
import getopt
import os.path
import cPickle
import hashlib

# CONFIG
CONF_DIR = "/Users/sujit/.sync_conf"
RSYNC_RSH = "ssh -p 9922"
# CONFIG

def usage():
  """
  Print usage information to the console and exits.
  """
  print "Usage: sync.py profile"
  print "       sync.py --list|--help"
  print "--list|-l: list available profiles"
  print "--help|-h: print this message"
  print "profile: name of profile.prf file to use"
  sys.exit(-1)

def list_profiles():
  """ 
  Print list of available profiles to the console and exits to
  the OS. Profiles are stored as .prf files in CONF_DIR.
  """
  print "Available Profiles:"
  for file in os.listdir(CONF_DIR):
    if (file.endswith(".prf")):
      print " ", file
  sys.exit(-1)

def abs_path(dirname, filename):
  """
  Convenience method to construct an absolute path for the given
  directory and file. Similar to the Java File constructor, except
  that this will not resolve properly in Windows systems (I think).
  @param dirname - the name of the directory
  @param filename - the name of the file.
  @return the absolute pathname for the file.
  """
  return os.sep.join([dirname, filename])

def get_configuration(profile):
  """
  Read the configuration off the .prf file into a dictionary
  for programmatic access.
  @param profile - the name of the sync profile.
  @return the dictionary containing configuration key-value pairs.
  """
  conf = {}
  profile_file = open(abs_path(CONF_DIR, profile + ".prf"), 'rb')
  for line in profile_file.readlines():
    (key, value) = line[:-1].split("=")
    conf[key] = value
  profile_file.close()
  return conf

def compute_file_md5_hash(file):
  """
  Computes the MD5 Hash of the named file. To avoid out of memory
  for large files, we read in the file in chunks of 1024 bytes each
  (any multiple of 128 bytes should work fine, since that is MD5's
  internal chunk size) and build up the md5 object.
  @param file - the name of the file to compute MD5 hash of.
  @return the MD5 digest of the file.
  """
  md5 = hashlib.md5()
  f = open(file, 'rb')
  while (True):
    chunk = f.read(1024)
    if (chunk == ""):
      break
    md5.update(chunk)
  f.close()
  return md5.digest()

def compute_md5_hashes(snapshot, dirname, fnames):
  """
  Visit named file and compute its MD5 hash, and store it into the
  snapshot dictionary.
  @param snapshot - a reference to the dictionary.
  @param dirname - the name of the current directory.
  @param fnames - the names of the files in the directory.
  """
  for fname in fnames:
    absname = abs_path(dirname, fname)
    if os.path.isfile(absname):
      snapshot[absname] = compute_file_md5_hash(absname)

def save_snapshot(profile, src):
  """
  Recursively traverse the directory tree rooted in src and compute
  the MD5 hash for each file. Write out the dictionary in pickled
  form to the snapshot (.dat) file.
  @param profile - the name of the sync profile.
  @param src - the local directory root.
  """
  snapshot = {}
  os.path.walk(src, compute_md5_hashes, snapshot)
  snapshot_file = open(abs_path(CONF_DIR, profile + ".dat"), 'wb')
  cPickle.dump(snapshot, snapshot_file, protocol=1)
  snapshot_file.close()

def load_snapshot(profile):
  """
  Loads the snapshot dictionary containing full path names of source
  files with their MD5 hash values from the pickled file.
  @param profile - the name of the sync profile.
  @return the dictionary of path name to MD5 hashes.
  """
  snapshot_file = open(abs_path(CONF_DIR, profile + ".dat"), 'rb')
  snapshot = cPickle.load(snapshot_file)
  snapshot_file.close()
  return snapshot

def check_if_changed(args, dirname, fnames):
  """
  Visits each file and computes the MD5 checksum, then compares it
  with the checksum available in the snapshot. If no checksum exists
  in the snapshot, it is considered to be a new file (ie, created
  since the last sync was done).
  @param args - a tuple containing the snapshot dictionary and the
             set of changed files so far.
  @param dirname - the name of the current directory.
  @param fnames - the names of the files in the directory.
  """
  (snapshot, changed_files) = args
  for fname in fnames:
    absname = abs_path(dirname, fname)
    try:
      orig_md5 = snapshot[absname]
      new_md5 = compute_file_md5_hash(absname)
      if (orig_md5 != new_md5):
        changed_files.add(absname)
    except KeyError:
      continue
    except TypeError:
      continue

def get_changed_since_sync(profile, src):
  """
  Computes a set of local file names which changed since the last time
  the sync was run. This is to prevent clobbering of local files by
  remote files containing a newer timestamp. The method walks the
  directory tree rooted in src and computes the checksum of each file
  in it, comparing it to the checksum from the snapshot. If the checksum
  differs, then it is written to the changed_files set.
  @param profile - the name of the sync profile.
  @param src - the local directory root.
  @return - a (possibly empty) set of changed file names, relative to
            the src directory.
  """
  snapshot = load_snapshot(profile)
  changed_files = set()
  os.path.walk(src, check_if_changed, (snapshot, changed_files))
  return map(lambda x: x.replace(src + os.sep, ""), changed_files)

def run_rsync_command(profile, src, dest, conf, force, files=[]):
  """ 
  Generate the rsync command for the OS to run based on input parameters.
  The output of the OS command is filtered to extract the files that
  are affected and a list of file names is returned.
  @param profile - the name of the sync profile.
  @param src - the local root.
  @param dest - the remote root.
  @param conf - a reference to the configuration dictionary.
  @param force - if set to false, rsync will be run in --dry-run mode,
                 ie, no files will be transferred.
  @param files - if provided, only the files in the list will be synced.
  @return a list of files affected.
  """
  # set up the basic command (we just add things to it for different
  # cases)
  command = " ".join(["rsync",
      "" if force else "--dry-run",
      "--cvs-exclude",
      " ".join(map(lambda x: "--exclude=" + x, conf["excludes"].split(","))),
      "--delete",
      "--update",
      "--compress",
      "-rave",
      "'" + RSYNC_RSH + "'"
  ])
  from_file_name = ""
  if (len(files) > 0):
    # create a text file and use --files-from parameter to only
    # sync files in the files-from file
    from_file_name = abs_path(CONF_DIR, profile + ".list")
    filelist = open(from_file_name, 'wb')
    for file in files:
      filelist.write(file.replace(conf["local"] + os.sep, "") + os.linesep)
    filelist.flush()
    filelist.close()
    command = " ".join([command,
      "--files-from=" + from_file_name, src + "/", dest])
  else:
    command = " ".join([command, src + "/", dest])
  # run the command
  result = []
  for line in os.popen(command):
    if (len(line.strip()) == 0 or
        line.find("file list") > -1 or
        line.find("total size") > -1 or
        (line.find("sent") > -1 and line.find("received") > -1)):
       continue
    result.append(line[:-1])
  if (len(from_file_name) > 0 and os.path.exists(from_file_name)):
    os.remove(from_file_name)
  return result

def bidirectional_sync(profile, src, dest, conf):
  """
  The algorithm consists of multiple rsync commands. Inline comments
  describe this in more detail. These checks are meant to prevent
  clobbering of local changes. The set of files that do not have a
  conflict (in either direction) are presented to the user for
  approval and two rsyncs are done. Then the conflicts are presented
  one by one. In most cases, the user should choose [u]psync.
  @param profile - the name of the sync profile. At the end of the
  sync operation, a snapshot of the current sync is stored.
  @param src - the local root.
  @param dest - the remote root.
  @param conf - the sync configuration.
  """
  # first find the local changes since the last sync. If
  # there is no .dat file, then ignore this step
  changed_since_sync = set()
  if (os.path.exists(abs_path(CONF_DIR, profile + ".dat"))):
    changed_since_sync = get_changed_since_sync(profile, src)
  # then do a dry-run of a downsync to get remote files to sync
  remote_changes = run_rsync_command(profile, dest, src, conf, False)
  # downsync only the files which are NOT in the changed_since_sync list.
  # To do this, we partition the remote_changes list into two sets
  non_conflicts, conflicts = [], []
  for remote_change in remote_changes:
    if (remote_change in changed_since_sync):
      conflicts.append(remote_change)
    else:
      non_conflicts.append(remote_change)
  remote_changes = []
  remote_changes.extend(non_conflicts)
  # do a dry-run of the upsync to get local files to upload
  local_changes = run_rsync_command(profile, src, dest, conf, False)
  # remove from conflicts that appear in changed_since_sync
  for local_change in local_changes:
    if (local_change in conflicts):
      conflicts.remove(local_change)
  # merge remote_ok and changed_since_sync, with the appropriate signage
  for remote_change in remote_changes:
    print "L<--R", remote_change
  for local_change in local_changes:
    print "L-->R", local_change
  if (len(remote_changes) + len(local_changes) > 0):
    yorn = raw_input("Is this OK [y/n/q]? ")
    if (yorn == 'y' or yorn == 'Y'):
      # do the rsync
      run_rsync_command(profile, src, dest, conf, True, local_changes)
      run_rsync_command(profile, dest, src, conf, True, remote_changes)
    elif (yorn == "q" or yorn == "Q" or yorn == "n" or yorn == "N"):
      return
  # lastly, take care of the conflicts on a per-file basis
  for conflict in conflicts:
    conflict_list = []
    action = raw_input("L<X>R " + conflict + " [u/d/n/q]? ")
    if (action == "u" or action == "U"):
      conflict_list.append(conflict)
      run_rsync_command(profile, src, dest, conf, True, conflict_list)
    elif (action == 'd' or action == 'D'):
      conflict_list.append(conflict)
      run_rsync_command(profile, dest, src, conf, True, conflict_list)
    elif (action == 'n' or action == 'N'):
      continue
    else:
      continue
  save_snapshot(profile, src)

def main():
  """
  This is how we are called. See usage() or call the script with the
  --help option for more information.
  """
  if (len(sys.argv) == 1):
    usage()
  (opts, args) = getopt.getopt(sys.argv[1:], "lh", ["list", "help"])
  for option, argval in opts:
    if (option in ("-h", "--help")):
      usage()
    elif (option in ("-l", "--list")):
      list_profiles()
  profile = sys.argv[1]
  # read the profile file
  conf = get_configuration(profile)
  # do the bidirectional sync
  bidirectional_sync(profile, conf["local"], conf["remote"], conf)

if (__name__ == "__main__"):
  main()

Usage

To get the list of available profiles, type sync.py --list. To add or edit a profile, go to the ~/.sync_conf directory and create or edit the profile's .prf file. This is actually simpler (copy an existing .prf and modify it) than doing it through a GUI.

A sample run is shown below. As you can see, it correctly detects changes on both systems. I have also tested the situation where a remote change is newer than a corresponding local change, and it successfully detects the conflict and allows me to upsync or downsync as I see fit.

sujit@cyclone:~$ sync.py test
L<--R ./
L<--R tunnel-indexer.prf
L-->R ./
L-->R tunnel-util.prf
Is this OK [y/n/q]? y
sujit@cyclone:~$

The script is obviously not a Unison replacement, but it works for me. I would probably start using Unison again if it worked on my setup, but until then, this script should suffice.