Deep Learning with Heterogeneous Graph Embeddings for Mortality Prediction from Electronic Health Records

Computational prediction of in-hospital mortality in the setting of an intensive care unit can help clinical practitioners guide care and make early intervention decisions. As clinical data are complex and varied in structure and components, continued innovation in modeling strategies is required to identify architectures that best model outcomes. In this work, we train a Heterogeneous Graph Model (HGM) on Electronic Health Record data and use the resulting embedding vector as additional information for a Convolutional Neural Network (CNN) model predicting in-hospital mortality. We show that the additional information provided by including time as a vector in the embedding captures the relationships between medical concepts, lab tests, and diagnoses, which enhances predictive performance. We find that adding HGM to a CNN model increases mortality prediction accuracy by up to 4%. This framework serves as a foundation for future experiments involving different EHR data types on important healthcare prediction tasks.


Introduction
Timely prediction of in-hospital mortality within intensive care units (ICUs) is beneficial [20,11], allowing practitioners to tailor care and intervene earlier to prevent deterioration [6,16]. Electronic Health Record (EHR) data consist of information relating to patient encounters with a health system, such as demographics, disease diagnoses, vital signs, and medications [10,8], and are often used for machine learning (ML) predictions on different tasks in the biomedical domain, including mortality prediction [19,21,9]. The inherent complexity of EHR data often requires advanced modeling frameworks to achieve robust performance on these tasks. A common modeling approach in EHR research is a 2-dimensional convolutional neural network (CNN) with one dimension as time and the other as clinical features [22,13,2]. In healthcare-related CNN models, various medical features are normally concatenated to be used directly as inputs and to create embeddings [18,5,15]. This form of feature representation can be powerful, but it disregards the graphical structure and interconnectivity between medical concepts [4,3], which can degrade CNN performance, especially since EHR data are often sparse due to missingness [2].
In this work, we propose a Heterogeneous Graph Model (HGM) to create a patient embedding vector, which better accounts for missingness in data for training a CNN model. The HGM model captures the relationships between different medical concept types (e.g., diagnoses and lab tests) due to its graphical structure. This relational representation facilitates capturing more complex patient patterns and encoding similarities.

Dataset
We conduct our experiments on de-identified EHR data from MIMIC-III [12]. This data set contains various clinical data relating to patient admissions to the ICU, such as demographics, lab test results, and disease diagnoses. We collected data for 5,956 patients, extracting lab tests every hour from admission. There are a total of 409 unique lab tests and 3,387 unique disease diagnoses observed. We bin the lab test events into 6, 12, 24, and 48 hours prior to patient death or discharge from the ICU. From these data, we perform 10-fold cross-validated mortality prediction.

Convolutional Neural Network Model
CNNs are often used for, and perform well on, image processing tasks [14] due to their inherent feature extraction and abstraction abilities, which increase accuracy on classification tasks. Studies have also demonstrated encouraging successes in using CNNs for EHR analyses. In this work, we use a standard CNN model as the baseline.
Since CNNs typically require two-dimensional inputs, we treat time as the horizontal dimension and medical events as the vertical dimension. For the time dimension, we record every event in one-hour binned increments with respect to the patient's death or discharge time. In this model, the vertical dimension is constructed by concatenating two medical event vectors: lab tests and diagnoses. Every entry of the lab test vector records the value of a specific lab test by hour. For the diagnosis vector, the i-th entry is 1 if the i-th diagnosis is observed and 0 otherwise. We treat mortality prediction as binary classification, using a softmax layer with two dimensions and cross-entropy loss.
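As an illustration, the two-dimensional CNN input described above can be assembled as follows. This is a minimal sketch with toy dimensions; the helper name and the sparse-dictionary input layout are our own assumptions, not from the paper (which uses 409 lab tests and 3,387 diagnoses):

```python
import numpy as np

def build_input_matrix(lab_events, diagnoses, n_labs, n_diagnoses, n_hours):
    """Assemble the 2-D CNN input: time on the horizontal axis,
    concatenated lab-test and diagnosis vectors on the vertical axis.

    lab_events: dict mapping (lab_index, hour) -> measured value
    diagnoses:  set of observed diagnosis indices (static for the stay)
    """
    X = np.zeros((n_labs + n_diagnoses, n_hours))
    for (lab_idx, hour), value in lab_events.items():
        X[lab_idx, hour] = value       # hourly-binned lab value
    for d in diagnoses:
        X[n_labs + d, :] = 1.0         # diagnosis indicator, constant over time
    return X
```

With hourly binning, unmeasured lab entries simply remain zero, which is one way the sparsity discussed above surfaces in the input.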

Heterogeneous Graph Model
The features used in the baseline CNN model are essentially raw data concatenated together, which does not account for the relationships between medical concepts. We use an HGM to capture these inherent relationships by creating three different types of nodes: patient, lab test, and diagnosis. These node types are connected by two relation types, tested and diagnosed, which can be represented as two kinds of triples: a tested triple indicates that a specific lab test was given to a patient at a specific time, and a diagnosed triple indicates that a patient was diagnosed with a disease.
To represent the lab test and diagnosis node types, we use multi-hot encoding vectors X_l ∈ {0, 1}^409 and X_d ∈ {0, 1}^3387, where a value of 1 in the i-th entry indicates that the i-th lab test was performed or the i-th diagnosis was given.
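The multi-hot encoding is straightforward to construct; a short sketch (function name is our own):

```python
import numpy as np

def multi_hot(indices, dim):
    """Multi-hot encoding: entry i is 1 iff concept i was observed."""
    v = np.zeros(dim, dtype=np.int8)
    v[list(indices)] = 1
    return v

# e.g. a node whose lab tests 3 and 7 were observed, out of 409
x_l = multi_hot({3, 7}, 409)
```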

Node Embeddings
To capture the relations between the different medical events related to a patient, we utilize the TransE model [1] to project the different types of nodes into the same latent space, so that connected nodes form a similar group and disconnected nodes a dissimilar group.
The TransE model uses a set of (1) projection matrices and (2) relation vectors. After initialization, projections and translations are optimized end-to-end. Heterogeneous nodes X_p, X_l, X_d are projected into a shared latent space with trainable projection matrices W_p, W_l, W_d using the nonlinear mappings

c_p = σ(W_p X_p),  c_l = σ(W_l X_l),  c_d = σ(W_d X_d),

where σ is a non-linear activation function and c_p, c_l, c_d are the latent representations of each node type. Although the EHR uses different dimensions for the different data types X_p, X_l, X_d, all node types are projected into the same latent space. We then apply translation operations to link these different types of nodes:

c_l + r_lp ≈ c_p,  c_p + r_pd ≈ c_d,

where r_lp and r_pd are the relation vectors connecting patients to lab tests and diagnoses, respectively. The patient representation c_p in both translations uses the same projection matrix W_p.
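The projection and translation steps can be sketched as follows. This is a toy illustration with made-up dimensions and randomly initialized parameters; the choice of tanh for σ and the L2 distance as the triple score are our own assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigma(x):
    return np.tanh(x)  # one possible non-linear activation (assumption)

def project(W, x):
    """Project a raw node vector into the shared latent space: c = sigma(W x)."""
    return sigma(W @ x)

# Toy dimensions; the paper uses 409 lab tests and 3,387 diagnoses.
d_latent, d_p, d_l, d_d = 8, 5, 6, 7
W_p, W_l, W_d = (rng.normal(size=(d_latent, d)) for d in (d_p, d_l, d_d))
r_lp, r_pd = rng.normal(size=d_latent), rng.normal(size=d_latent)

x_p, x_l, x_d = rng.normal(size=d_p), rng.normal(size=d_l), rng.normal(size=d_d)
c_p, c_l, c_d = project(W_p, x_p), project(W_l, x_l), project(W_d, x_d)

# TransE-style translation scores: a small distance means the triple
# (lab tested for patient / patient diagnosed with disease) is plausible.
score_tested    = np.linalg.norm(c_l + r_lp - c_p)
score_diagnosed = np.linalg.norm(c_p + r_pd - c_d)
```

In training, the projection matrices and relation vectors would be updated so that observed triples score low and unobserved ones high.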

Optimization Model
For training the HGM, we apply a skip-gram optimization model [7], which increases the proximity between embedding points whose corresponding graph nodes are frequently connected after the projection and translation operations:

argmax_θ Σ_{u ∈ V} Σ_{t ∈ T_V} Σ_{c_t ∈ N_t(u)} log p(c_t | u; θ),

where N_t(u) is the set of neighborhood vertices of center node u and t ∈ T_V is the node type. Here, we learn the node embeddings by maximizing the probability of correctly predicting the patient node's associated lab tests and diagnoses. The prediction probability is modeled as a softmax function:

p(c_t | u; θ) = e^{c_t · u} / Z_u,

where u is the latent representation of patient u, c_t is the latent representation of a lab or diagnosis neighbor of node u, and c_t · u is the inner product of the two embedding vectors, representing their similarity. Z_u = Σ_{v ∈ V} e^{v_t · u} is the normalization term, a sum over all vertices V across all node types. Exact computation of Z_u is intractable for large-scale graphs, so we adopt a negative sampling strategy [17] to approximate the normalization factor, yielding the final optimization objective:

log σ(c_t · u) + Σ_{j=1}^{K} E_{c_j ∼ P_v(c_j)} [log σ(−c_j · u)],

where σ(x) = 1/(1 + exp(−x)), K is the number of negative samples, and P_v(c_j) is the negative sampling distribution.
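The negative-sampling objective above can be written compactly as a loss to minimize; a minimal numpy sketch (the function name and calling convention are ours):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def neg_sampling_loss(u, pos, negs):
    """Skip-gram objective with negative sampling: maximize
    log sigmoid(c_pos . u) + sum_j log sigmoid(-c_neg_j . u).
    We return the negated objective as a loss to minimize."""
    obj = np.log(sigmoid(pos @ u))
    for c in negs:
        obj += np.log(sigmoid(-(c @ u)))
    return -obj
```

A center embedding aligned with its true neighbor and opposed to its negatives yields a small loss; the reverse configuration yields a large one.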
For training the HGM, we perform heterogeneous neighborhood sampling by one-hop connectivity and pick the Patient node as the center node, since it has one-hop connections to both Diagnosis and Lab test nodes. Specifically, for one training center Patient node, we uniformly sample 10 directly connected Diagnosis nodes and 10 directly connected Lab test nodes. From these 10 sampled Diagnosis nodes, we sample another 10 Patient nodes, each connected to one of the prior 10 Diagnosis nodes. In this way, we connect the center Patient node with similar Patient nodes through their common diagnoses. We also sample the Patient node belonging to the next hour corresponding to the center Patient node. For negative sampling [17], we sample uniformly from all Diagnosis and Lab test nodes that do not have one-hop connections with the center Patient node. We then project these different nodes into the same latent space through the TransE model. After unifying the embeddings for the different node types, each concept is represented as a point in a Euclidean space, in which we can measure the similarity between any two vectors using the dot product.
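The sampling procedure can be sketched as below. This is a simplified illustration using adjacency dictionaries of our own design; it samples with replacement and omits the next-hour patient node and negative sampling for brevity:

```python
import random

def sample_neighborhood(patient, diag_of, lab_of, patients_with_diag, k=10, seed=0):
    """One heterogeneous neighborhood sample around a center Patient node:
    k Diagnosis neighbors, k Lab test neighbors, and k similar Patient
    nodes reached via shared diagnoses (with replacement when fewer
    than k distinct neighbors exist)."""
    rng = random.Random(seed)
    diags = rng.choices(sorted(diag_of[patient]), k=k)
    labs = rng.choices(sorted(lab_of[patient]), k=k)
    # For each sampled diagnosis, pick another patient sharing it.
    similar = [rng.choice(sorted(patients_with_diag[d] - {patient})) for d in diags]
    return diags, labs, similar
```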

HGM Embeddings with CNN Model
The HGM embedding vector encodes not only a patient's information, but also their relations with diagnoses, lab tests, and subsequent lab test results in time. The patient node is represented as a vector X_p ∈ R^477 containing the numerical values measured from lab tests averaged at that time step. We concatenate the resulting embedding vectors onto the baseline CNN's vertical feature dimension to form a final feature vector within every hour, and use these new features as the CNN input to predict mortality. In addition, since we encode time as a relation type, we can infer the embedding vector of time steps with missing data based on information from the previous hour. We visualize this procedure in Fig. 1.
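The concatenation step amounts to stacking the per-hour HGM embedding under the raw features along the CNN's vertical axis; a minimal sketch (function name and shapes are our own, and we assume the hourly embeddings are already computed):

```python
import numpy as np

def augment_features(raw_hourly, hgm_hourly):
    """Concatenate raw per-hour features with per-hour HGM patient
    embeddings along the vertical (feature) axis of the CNN input.

    raw_hourly: (n_raw, n_hours) matrix of raw lab/diagnosis features
    hgm_hourly: (n_emb, n_hours) matrix of HGM patient embeddings
    """
    assert raw_hourly.shape[1] == hgm_hourly.shape[1]  # same time axis
    return np.vstack([raw_hourly, hgm_hourly])
```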

Experiments
We aim to predict mortality 6, 12, 24, and 48 hours prior to death and/or discharge. The CNN model is used for prediction as introduced in Section 2.2. We compare three different scenarios to test the impact of adding HGM embedding vectors as additional features to the framework:
• HGM: embed the raw patient lab and diagnosis data.
• CNN: use the raw lab test features.
• HGM+CNN: concatenate the HGM patient embedding vector with the raw lab test feature vector.
In this work, we use AUROC and AUPRC scores as the primary performance metrics. We tabulate the results in Table 2 and show the evaluation AUROC and AUPRC curves for these tasks in Fig. 2. The testing results show that HGM+CNN outperforms both the basic HGM and CNN models, indicating that the additional information from the HGM patient embeddings increases the accuracy of predicting in-patient mortality. The prediction accuracy across the different hours prior to death and/or discharge does not vary much, indicating that different time windows do not have a major impact on the result for this particular task and modeling strategy. The prediction accuracy of the CNN model drops by 1% in the case of six hours prior to death and/or discharge, but not in the other two models, indicating that using the embedding features from the HGM model is slightly more robust than using the raw data.

Discussion and Conclusion
In this work, we propose a method to incorporate a patient embedding vector from an HGM model into a CNN model in order to provide more information via the interconnectivity between different clinical concepts. We assess the value of this implementation on a task of predicting mortality in EHR data. The results of our experiments show the superior performance of adding the additional patient embedding vector, pretrained from the HGM model, compared to using purely raw features as the input to the CNN model. In one aspect, this is because the HGM embedding vector captures additional relational information between different medical concepts, thus providing additional information to the CNN model. Furthermore, we observe that concatenating the HGM embedding vector with diagnosis feature vectors does not increase accuracy over the concatenation of raw lab test and diagnosis feature vectors. This finding indicates that the raw lab test feature vector provides unique information for the CNN to utilize. At the same time, it indicates that the embedded patient vector from the HGM model may lose some information from the raw lab test features during the projection of these data into a low-dimensional latent space. By concatenating all feature vectors, we aim to preserve the information from different data points, which helps to achieve higher mortality prediction accuracy. We hope the findings from this work can be expanded in future directions that add more EHR node types and time components on a variety of other important health-related predictive tasks.

Author Contributions
TW designed the study and performed the analyses. TW and BSG wrote the manuscript. TW, HH, AA, YD, and BSG evaluated the results and edited the manuscript. YD, AA, and BSG supervised the project.