Saturday, December 13, 2014

Computing a Comorbidity Network with Spark


I've been meaning to learn about Apache Spark since I read Fast Data Processing with Spark and attended this meetup earlier this year. I have also been experimenting with some anonymized claims data published by the Centers for Medicare and Medicaid Services (CMS), consisting of 1.3M inpatient and 15.8M outpatient records for 2.3M unique members. To get familiar with the Spark API, I came upon the idea of creating a network of member comorbidities, similar to the idea described in the paper Flavor Network and the Principles of Food Pairing by Ahn, et al. The Flavor Network analyzes food pairings based on shared flavor compounds, whereas my Comorbidity Network analyzes disease pairings based on shared procedure codes. While it may be intuitively simpler to use the co-occurrence frequency directly from the member data, this pairing will hopefully produce different insights. This post describes the approach and the code to create this network.

The member data is annotated with the member's comorbidities, i.e., the diseases the member has. Each member record is annotated with up to 11 diseases using a binary indicator variable (1=Yes, 2=No). The published data is for the period 2008-2010, and there are multiple entries for the same member, probably corresponding to data at different time points in this period. Each member can have zero or more inpatient and outpatient claims. Each claim has one or more procedure codes: for inpatient records these are ICD-9 procedure codes, and for outpatient records these are HCPCS codes. The diagram below shows the interaction between the member and claims data, with the uninteresting fields elided.


In order to transform this data into a Comorbidity Network, we need to run it through a series of transformations. The Spark API exposes a data structure called a Resilient Distributed Dataset (RDD), with methods that allow you to treat it like a Scala collection, very similar to Scalding. The set of transformations is summarized in the diagram below.


The implementation code is housed in a project built from the Spark Example Project template from Snowplow Analytics on GitHub, updated to use Spark 1.1 and Scala 2.10, and JUnit4 instead of Specs2 for tests. The example demonstrates use of the Spark API to implement the Word Count example, and is instrumented to generate artifacts to run the job on a Spark cluster. Here is the code for the pipeline. For ease of testing, I have modeled individual operations as independent functions. The execute() function ties these operations into the pipeline described in the diagram above.

// Source: src/main/scala/com/mycompany/diseasegraph/GraphDataGenerator.scala
package com.mycompany.diseasegraph

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import SparkContext._

object GraphDataGenerator {

  private val AppName = "GraphDataGeneratorJob"

  def execute(master: String, args: List[String], 
      jars: Seq[String]=Nil): Unit = {
    
    val sc = new SparkContext(master, AppName, null, jars)
    
    // dedup member so the record with the highest number
    // of comorbidities is retained
    val membersDeduped = dedupMemberInfo(sc.textFile(args(0)))
    // normalize member to (member_id,(disease,weight))
    val membersNormalized = normalizeMemberInfo(membersDeduped)

    // normalize inpatient and outpatient claims to 
    // (member_id,(code,weight))
    // This involves grouping by (member_id, claim_id) and
    // computing the code weights, then removing claim_id.
    // Codes used in inpatient claims are ICD-9 and the
    // ones used for outpatient claims are HCPCS
    val claimsNormalized = normalizeClaimInfo(
      sc.textFile(args(1)), (30, 35)) ++
      normalizeClaimInfo(sc.textFile(args(2)), (31, 75))
    
    // join the membersNormalized and claimsNormalized RDDs
    // on memberId to get mapping of disease to code
    // membersNormalized: (member_id, (disease, d_weight))
    // claimsNormalized: (member_id, (code, c_weight))
    // diseaseCodes: (disease, code, weight)
    val diseaseCodes = joinMembersWithClaims(
      membersNormalized, claimsNormalized)
    diseaseCodes.map(t => format(t)).saveAsTextFile(args(3))
    
    // finally do a self join with the diseaseCodes RDD joining
    // on code to compute a measure of disease-disease similarity
    // by the weight of the shared procedure codes to produce
    // (disease_A, disease_B, weight)
    val diseaseAllPairs = selfJoinDisease(diseaseCodes)
    diseaseAllPairs.map(t => format(t)).saveAsTextFile(args(4))
  }
  
  def format(t: (String,String,Double)): String = {
    "%s,%s,%.3f".format(t._1, t._2, t._3)
  }
  
  val columnDiseaseMap = Map(
    12 -> "ALZM",
    13 -> "CHF",
    14 -> "CKD",
    15 -> "CNCR",
    16 -> "COPD",
    17 -> "DEP",
    18 -> "DIAB",
    19 -> "IHD",
    20 -> "OSTR",
    21 -> "ARTH",
    22 -> "TIA"
  )

  def dedupMemberInfo(input: RDD[String]): 
      RDD[(String,List[String])] = {
    input.map(line => {
      val cols = line.split(",")
      val memberId = cols(0)
      val comorbs = columnDiseaseMap.keys
        .toList.sorted
        .map(e => cols(e))
      (memberId, comorbs.toList)
    })
    .reduceByKey((v1, v2) => {
      // 1 == Yes, 2 == No; the indicators are Strings, so compare
      // against "1" (comparing a String with the Int 1 is always false)
      val v1fsize = v1.count(_ == "1")
      val v2fsize = v2.count(_ == "1")
      if (v1fsize > v2fsize) v1 else v2
    })
  }
  
  def normalizeMemberInfo(input: RDD[(String,List[String])]): 
      RDD[(String,(String,Double))]= {
    input.flatMap(elem => {
      val diseases = elem._2.zipWithIndex
        .map(di => if (di._1.toInt == 1) columnDiseaseMap(di._2 + 12) else "")
        .filter(d => ! d.isEmpty())
      val weight = 1.0 / diseases.size
      diseases.map(disease => (elem._1, (disease, weight)))
    })
  }
  
  def normalizeClaimInfo(input: RDD[String], 
      pcodeIndex: (Int,Int)): RDD[(String,(String,Double))] = {
    input.flatMap(line => {
      val cols = line.split(",")
      val memberId = cols(0)
      val claimId = cols(1)
      val procCodes = cols.slice(pcodeIndex._1, pcodeIndex._2)
      procCodes.filter(pcode => ! pcode.isEmpty)
        .map(pcode => ("%s:%s".format(memberId, claimId), pcode))
    })
    .groupByKey()
    .flatMap(grouped => {
      val memberId = grouped._1.split(":")(0)
      val weight = 1.0 / grouped._2.size
      val codes = grouped._2.toList
      codes.map(code => {
        (memberId, (code, weight))
      })
    })
  }
  
  def joinMembersWithClaims(members: RDD[(String,(String,Double))],
      claims: RDD[(String,(String,Double))]): 
      RDD[(String,String,Double)] = {
    members.join(claims)
      .map(rec => {
        val disease = rec._2._1._1
        val code = rec._2._2._1
        val weight = rec._2._1._2 * rec._2._2._2
        (List(disease, code).mkString(":"), weight)
      })
      .reduceByKey(_ + _)
      .sortByKey(true)
      .map(kv => {
        val Array(k,v) = kv._1.split(":")
        (k, v, kv._2)
      })
  }
  
  def selfJoinDisease(dcs: RDD[(String,String,Double)]): 
      RDD[(String,String,Double)] = {
    val dcsKeyed = dcs.map(t => (t._2, (t._1, t._3)))
    dcsKeyed.join(dcsKeyed)
      // join on code and compute edge weight 
      // between disease pairs
      .map(rec => {
        val ldis = rec._2._1._1
        val rdis = rec._2._2._1
        val diseases = Array(ldis, rdis).sorted
        val weight = rec._2._1._2 * rec._2._2._2
        (diseases(0), diseases(1), weight)
      })
      // filter out cases where LHS == RHS
      .filter(t => ! t._1.equals(t._2))
      // group on (LHS + RHS)
      .map(t => (List(t._1, t._2).mkString(":"), t._3))
      .reduceByKey(_ + _)
      .sortByKey(true)
      // convert back to triple format
      .map(p => {
        val Array(lhs, rhs) = p._1.split(":")
        (lhs, rhs, p._2)
      })
  }
}

On the member side, we start with the input files of denormalized, duplicate member data that looks like this:

0001448457F2ED81,19750401,,2,2,0,15,150,12,12,12,12,1,1,1,2,1,1,1,1,2,2,2,21000.00,4096.00,0.00,670.00,430.00,0.00,2070.00,450.00,70.00
0001448457F2ED81,19750401,,2,2,0,15,150,12,12,12,12,1,1,1,2,2,1,1,1,2,2,2,21000.00,4096.00,0.00,670.00,430.00,0.00,2070.00,450.00,70.00
000188A3402777A5,19220401,,2,5,0,31,270,12,12,0,12,1,1,1,1,1,1,1,1,2,2,2,0.00,0.00,0.00,2100.00,890.00,0.00,2330.00,690.00,30.00
0001A1CB9A542107,19220201,,2,1,0,45,610,12,12,0,08,2,2,2,2,2,2,1,1,2,2,2,0.00,0.00,0.00,4890.00,1460.00,0.00,1060.00,290.00,300.00
0001EB1229306825,19540501,,1,3,0,22,090,12,12,0,12,2,1,2,1,2,2,2,1,2,1,2,0.00,0.00,0.00,660.00,410.00,0.00,4470.00,1220.00,20.00
000238CD55A321A2,19110501,,2,1,0,19,350,12,12,0,12,1,2,2,2,2,1,1,2,1,2,2,0.00,0.00,0.00,0.00,0.00,0.00,750.00,180.00,0.00
...

Our first operation extracts the member_id and comorbidity indicators from each record, then reduces by member_id so that, of all the duplicate records for a member, only the one with the maximum number of comorbidities is kept. The member data is transformed to something like this:

0001EB1229306825,2,1,2,1,2,2,2,1,2,1,2
000188A3402777A5,1,1,1,1,1,1,1,1,2,2,2
0001A1CB9A542107,2,2,2,2,2,2,1,1,2,2,2
000238CD55A321A2,1,2,2,2,2,1,1,2,1,2,2
0001448457F2ED81,1,1,1,2,2,1,1,1,2,2,2
...
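
The dedup rule described above can be sketched in plain Python, without Spark. This is only an illustration: the function name `dedup_members` and the sample rows are hypothetical, and ties here keep the first record seen, which is an arbitrary choice.

```python
# Plain-Python sketch of the dedup step (no Spark).
# The member ids and indicator lists below are made up for illustration.
def dedup_members(rows):
    """Keep, per member_id, the indicator list with the most "1" (Yes) flags."""
    best = {}
    for member_id, flags in rows:
        current = best.get(member_id)
        # Replace the stored record only if the new one has strictly more Yes flags
        if current is None or flags.count("1") > current.count("1"):
            best[member_id] = flags
    return best

rows = [
    ("A", ["1", "2", "2"]),
    ("A", ["1", "1", "2"]),
    ("B", ["2", "2", "2"]),
]
deduped = dedup_members(rows)
print(deduped["A"])  # ['1', '1', '2']
```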

Our next step normalizes the data by member_id and disease, assigning each disease a weight equal to the reciprocal of the number of diseases the member has. For example, the first member in the data above has 4 diseases, so each of these diseases is assigned a weight of 0.25. At this point, the data looks like this:

0001EB1229306825,CHF,0.250
0001EB1229306825,CNCR,0.250
0001EB1229306825,IHD,0.250
0001EB1229306825,ARTH,0.250
000188A3402777A5,ALZM,0.125
...
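
As a quick check of the arithmetic, here is a plain-Python sketch of the disease weighting (no Spark). The disease order follows columnDiseaseMap in the Scala code; the function name `disease_weights` is hypothetical.

```python
# Disease labels in column order (columns 12..22 in the member file)
DISEASES = ["ALZM", "CHF", "CKD", "CNCR", "COPD", "DEP",
            "DIAB", "IHD", "OSTR", "ARTH", "TIA"]

def disease_weights(member_id, flags):
    """Return (member_id, disease, weight) triples, weight = 1 / number of diseases."""
    diseases = [d for d, f in zip(DISEASES, flags) if f == "1"]
    if not diseases:
        # A member with no diseases contributes nothing
        return []
    weight = 1.0 / len(diseases)
    return [(member_id, d, weight) for d in diseases]

flags = "2,1,2,1,2,2,2,1,2,1,2".split(",")
for rec in disease_weights("0001EB1229306825", flags):
    print("%s,%s,%.3f" % rec)
```

For the first member in the deduplicated data this prints the four rows shown above, each with weight 0.250.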

On the claims side, each claim is keyed by member_id and claim_id. Each claim lists a bunch of procedure codes. ICD-9 codes are used for inpatient claims and HCPCS codes are used for outpatient claims. Processing for both kinds of claims is similar; the only difference is which columns the procedure codes are read from. Here is what outpatient claims data looks like.

00000B48BCF4AD29,391362254696220,1,20080315,20080315,10026U,70.00,0.00,2296429757,,,0.00,2721,V1250,2330,,,,,,,,,,,,,,0.00,0.00,,G0103,80076,80053,80053,85610,87086,82248,82565,84443,76700,80053,80061,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
00000B48BCF4AD29,391612254357344,1,20080405,20080405,0300YM,50.00,0.00,5824170688,,,0.00,72999,72981,,,,,,,,,,,,,,,0.00,10.00,7295,93971,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
00000B48BCF4AD29,391812254100253,1,20080604,20080621,1002GD,300.00,0.00,0631751348,,,0.00,7216,,,,,,,,,,,,,,,,0.00,40.00,7242,97110,97001,97010,97110,97110,97110,97110,97110,97110,97010,97140,97035,97110,97124,97530,97110,97012,97026,97001,97010,97140,97110,,,,,,,,,,,,,,,,,,,,,,,
00000B48BCF4AD29,391012254328356,1,20081024,20081024,1000GD,1100.00,0.00,7611012531,7611012531,,0.00,99671,4019,V1046,53081,V5861,3441,2859,,,,,,,,,,0.00,30.00,,,84484,V2632,J1170,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
00000B48BCF4AD29,391602254471197,1,20090414,20090414,1000UU,60.00,0.00,7190755841,,,0.00,7916,V7644,,,,,,,,,,,,,,,0.00,0.00,,36415,86800,85027,82043,84460,85025,82435,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
...

The first step is to extract the member_id, claim_id and procedure codes from the data and group by (member_id+claim_id) to find the number of codes per claim, weighting each code as the reciprocal of the number of codes present in the claim, similar to how we weighted the diseases. We then aggregate code weights for individual codes across claims for each member, arriving at normalized member data that looks like this:

00009C897C3D8372,82043,1.000
0000525AB30E4DEF,78315,0.333
0000525AB30E4DEF,83036,0.333
0000525AB30E4DEF,96372,0.333
0000525AB30E4DEF,99281,1.000
...
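
The per-claim code weighting can be sketched in plain Python as follows (no Spark). The function name `claim_code_weights` and the member, claim and code values are hypothetical, and the sketch assumes the codes have already been sliced out of each claim row.

```python
from collections import defaultdict

def claim_code_weights(claims):
    """claims: iterable of (member_id, claim_id, [codes]).
    Returns (member_id, code, weight) with weight = 1 / codes in that claim."""
    grouped = defaultdict(list)
    for member_id, claim_id, codes in claims:
        # Drop empty code slots, as the Scala filter does
        grouped[(member_id, claim_id)].extend(c for c in codes if c)
    out = []
    for (member_id, _claim_id), codes in grouped.items():
        if not codes:
            continue
        weight = 1.0 / len(codes)
        # The claim_id is dropped once the weight has been computed
        out.extend((member_id, code, weight) for code in codes)
    return out

# Hypothetical claims: one with three codes, one with a single code
claims = [
    ("M1", "C1", ["78315", "83036", "96372", ""]),
    ("M1", "C2", ["99281"]),
]
weights = claim_code_weights(claims)
for rec in weights:
    print("%s,%s,%.3f" % rec)
```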

We now join the normalized member data and normalized claim data on member_id, multiplying the disease and code weights and summing the products for each (disease, code) pair. This gives us (disease, code, weight) triples that look like this:

ALZM,7935,0.167
ALZM,9427,0.167
CHF,7935,0.167
CHF,9427,0.167
CKD,7935,0.167
...
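
A plain-Python sketch of this join (no Spark) may make the weight combination clearer. The function name `join_members_with_claims` mirrors the Scala one, but the inputs below are hypothetical values chosen so the products are easy to verify by hand.

```python
from collections import defaultdict

def join_members_with_claims(member_diseases, member_codes):
    """Join (member_id, disease, d_weight) with (member_id, code, c_weight)
    on member_id and sum d_weight * c_weight per (disease, code)."""
    codes_by_member = defaultdict(list)
    for member_id, code, c_weight in member_codes:
        codes_by_member[member_id].append((code, c_weight))
    totals = defaultdict(float)
    for member_id, disease, d_weight in member_diseases:
        # Members without claims simply drop out, as in an inner join
        for code, c_weight in codes_by_member.get(member_id, []):
            totals[(disease, code)] += d_weight * c_weight
    return sorted((d, c, w) for (d, c), w in totals.items())

# Hypothetical member with two diseases and two equally weighted codes
member_diseases = [("M1", "ALZM", 0.5), ("M1", "CHF", 0.5)]
member_codes = [("M1", "7935", 0.5), ("M1", "9427", 0.5)]
triples = join_members_with_claims(member_diseases, member_codes)
for rec in triples:
    print("%s,%s,%.3f" % rec)
```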

Joining this data with itself on the code gives us the disease pairs with a score that indicates the degree of similarity between them, in other words our Comorbidity Network.

ALZM,ARTH,10.206
ALZM,CHF,22.205
ALZM,CKD,14.718
ALZM,CNCR,4.790
ALZM,COPD,11.838
...
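
The self join can also be sketched in plain Python (no Spark). The function name `disease_pairs` and the input triples are hypothetical. Like the Scala self join, this visits each pair of diseases in both orders, so every shared code contributes its weight product twice to the pair's score.

```python
from collections import defaultdict
from itertools import product

def disease_pairs(disease_codes):
    """Self join (disease, code, weight) triples on code and sum the
    weight products per unordered disease pair."""
    by_code = defaultdict(list)
    for disease, code, weight in disease_codes:
        by_code[code].append((disease, weight))
    totals = defaultdict(float)
    for entries in by_code.values():
        # Cross product of the entries sharing a code, in both orders
        for (ldis, lw), (rdis, rw) in product(entries, repeat=2):
            if ldis != rdis:
                totals[tuple(sorted((ldis, rdis)))] += lw * rw
    return sorted((a, b, w) for (a, b), w in totals.items())

# Hypothetical input: ALZM and CHF share one code, CKD shares none
disease_codes = [
    ("ALZM", "7935", 0.25),
    ("CHF", "7935", 0.5),
    ("CKD", "9427", 1.0),
]
pairs = disease_pairs(disease_codes)
print(pairs)  # [('ALZM', 'CHF', 0.25)]
```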

When running this on Spark via the command line, the pipeline described would be invoked through a job class like this one. However, during development, I incrementally developed each of the operations individually and tested them using JUnit tests. I also used a JUnit test to run the full pipeline against my "local" Spark cluster on my laptop with a small subset of data consisting of 1,000 members and 5,000 each of inpatient and outpatient records. The (final) JUnit test to run the full pipeline is shown below. The full test suite is here.

// Source: src/test/scala/com/mycompany/diseasegraph/GraphDataGeneratorTest.scala
package com.mycompany.diseasegraph

import org.junit.Test

class GraphDataGeneratorTest {

  val TestName = "GraphDataGeneratorTest"
    
  ...

  @Test
  def testExecute(): Unit = {
    val args = List(
      "data/benefit_summary_1000.csv",
      "data/inpatient_claims_1000.csv",
      "data/outpatient_claims_1000.csv",
      "src/test/resources/final_outputs/disease_codes",
      "src/test/resources/final_outputs/diseases"
    )
    forceDeleteIfExists(args(3))
    forceDeleteIfExists(args(4))
    GraphDataGenerator.execute(master="local", args=args)
  }
}

I then used the generated data to construct a graph using the Python script shown below, adapted from this example on the Udacity wiki. The code reads the output of the Spark job, constructs a NetworkX graph object, and plots it using the Matplotlib library. The edge weights are normalized (so the maximum edge weight is 1), and only edges with a weight greater than 0.3 are shown to reduce clutter.

# Source: src/main/python/disease_graph.py
import networkx as nx
import matplotlib.pyplot as plt
import os

def draw_graph(G, labels=None, graph_layout='shell',
               node_size=1600, node_color='blue', node_alpha=0.3,
               node_text_size=12,
               edge_color='blue', edge_alpha=0.3, edge_tickness=1,
               edge_text_pos=0.3,
               text_font='sans-serif'):

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        graph_pos=nx.spring_layout(G)
    elif graph_layout == 'spectral':
        graph_pos=nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos=nx.random_layout(G)
    else:
        graph_pos=nx.shell_layout(G)
    # draw graph
    nx.draw_networkx_nodes(G,graph_pos,node_size=node_size, 
                           alpha=node_alpha, node_color=node_color)
    nx.draw_networkx_edges(G,graph_pos,width=edge_tickness,
                           alpha=edge_alpha,edge_color=edge_color)
    nx.draw_networkx_labels(G, graph_pos,font_size=node_text_size,
                            font_family=text_font)
    nx.draw_networkx_edge_labels(G, graph_pos, edge_labels=labels, 
                                 label_pos=edge_text_pos)
    # show graph
    frame = plt.gca()
    frame.axes.get_xaxis().set_visible(False)
    frame.axes.get_yaxis().set_visible(False)

    plt.show()

def add_node_to_graph(G, node, node_labels):
    if node not in node_labels:
        G.add_node(node)
    node_labels.add(node)

datadir = "../../../src/test/resources/final_outputs/diseases"
lines = []
for datafile in os.listdir(datadir):
    if datafile.startswith("part-"):
        # open in text mode so that split(",") works on str in Python 3 too
        fin = open(os.path.join(datadir, datafile), 'r')
        for line in fin:
            disease_1, disease_2, weight = line.strip().split(",")
            lines.append((disease_1, disease_2, float(weight)))
        fin.close()

max_weight = max([x[2] for x in lines])
norm_lines = map(lambda x: (x[0], x[1], x[2] / max_weight), lines)

G = nx.Graph()
edge_labels = dict()
node_labels = set()
for line in norm_lines:
    add_node_to_graph(G, line[0], node_labels)
    add_node_to_graph(G, line[1], node_labels)
    if line[2] > 0.3:
        G.add_edge(line[0], line[1], weight=line[2])
        edge_labels[(line[0], line[1])] = "%.2f" % (line[2])
draw_graph(G, labels=edge_labels, graph_layout="shell")

This code produces the Comorbidity Network shown below.


The network is obviously inaccurate because it was generated off a subset of the input data, so we can't conclude much from it. In this initial part of the task, I deliberately concentrated on just learning the Spark API without thinking too much about how I would process the full dataset. From the little I have read on the subject, one can use either Amazon's EC2 or EMR to run Spark jobs, but neither seemed very straightforward. I plan on spending some time figuring this out later and writing about it when successful.

That's all I have for today. My take on the Spark API is that it is really well-designed and has almost zero learning curve if you are already developing using Scala's Higher Order Functions. In some respects, Scalding's API is similar, but I found Spark's API to be more intuitive. If you would like to try the code out yourself, you can find it here on GitHub.

2 comments (moderated to prevent spam):

Jack Park said...

Fantastic!
I have to wonder if the comorbidity mechanism explained here can be applied to text analysis of scientific papers.

Sujit Pal said...

Thanks Jack. And yes, as long as there are two types of fields to pivot off of, it should be possible to apply this kind of analysis to text as well. I remember reading a paper where the author built a network of scientific disciplines using citations (similar to this one but not this one, the one I am thinking about had a different chart).