Introduction to Causal Inference in Machine Learning
Written by Natalie Volk. A discussion of the paper titled “Causal Inference in medicine and in health policy”.
Causal inference is a major area of research in machine learning that aims to incorporate an understanding of cause and effect into AI models. Researchers believe that doing so could improve model generalization and transparency, help tackle bias, and even contribute to human-like reasoning in Artificial General Intelligence (AGI).
What is Causal Inference?
Let’s start with what causal inference actually is. Causal inference is the process of determining causality: how a certain variable affects an outcome. You are probably familiar with the concept of cause and effect, but causal relationships are sometimes difficult to identify. When two variables are related in any way, humans are predisposed to assume that one causes the other, because cause and effect is the simplest kind of relationship to understand.
In reality, many relationships are merely correlations. A correlation can arise by chance, or because an additional variable creates an association between the two. For example, some studies concluded that drinking coffee increases your risk of lung cancer. However, an additional variable caused the correlation: smoking. It turned out that there was a higher proportion of smokers among people who drank a lot of coffee than among those who did not. Smoking here is a confounding variable: a variable that affects other variables in a way that creates distorted, or “spurious,” relationships between them. Confounders “confound” any genuine causal relationship between the variables.
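To make the coffee example concrete, here is a small simulation sketch. All of the numbers are invented: in this made-up world smoking raises both coffee drinking and lung-cancer risk, while coffee has no effect on cancer at all. Even so, coffee drinkers show a higher cancer rate overall, and the gap disappears once we look at smokers and non-smokers separately.

```python
import numpy as np

# Hypothetical simulation: smoking causes both coffee drinking and lung cancer,
# while coffee has NO effect on cancer. All probabilities are made up.
rng = np.random.default_rng(0)
n = 200_000

smoker = rng.random(n) < 0.3                           # 30% smokers
coffee = rng.random(n) < np.where(smoker, 0.7, 0.3)    # smokers drink more coffee
cancer = rng.random(n) < np.where(smoker, 0.10, 0.01)  # only smoking raises risk


def risk_difference(mask):
    """P(cancer | coffee) - P(cancer | no coffee) among the rows selected by mask."""
    return cancer[mask & coffee].mean() - cancer[mask & ~coffee].mean()


everyone = np.ones(n, dtype=bool)
naive = risk_difference(everyone)
by_stratum = {
    "smokers": risk_difference(smoker),
    "non-smokers": risk_difference(~smoker),
}

print(f"naive coffee/cancer risk difference: {naive:.3f}")
for name, diff in by_stratum.items():
    print(f"within {name}: {diff:.3f}")
```

With these invented probabilities the naive difference comes out clearly positive (around 0.03), while both within-group differences are close to zero: the whole association is produced by the confounder.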
Addressing confounding variables
The only way to truly eliminate confounders is to hold all external variables constant and change only the variable of interest; then the causal effect of that variable could be observed directly. Sadly, this is impossible: you would have to try one treatment, observe its effects, then go back in time and repeat everything exactly, except for that one variable. This impossibility is known as the fundamental problem of causal inference.
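One way to picture the fundamental problem is a table of “potential outcomes”: each patient has an outcome without treatment and an outcome with treatment, but only one of the two can ever be recorded. The sketch below uses four hypothetical patients whose unobservable column we simply invent, so we can compare the true average effect with what a naive comparison of the recorded outcomes suggests. The `Patient` class and its field names are illustrative only.

```python
from dataclasses import dataclass
from statistics import mean


@dataclass
class Patient:
    name: str
    y_untreated: int   # potential outcome Y(0): 1 = recovered without treatment
    y_treated: int     # potential outcome Y(1): 1 = recovered with treatment
    treated: bool      # which of the two "worlds" actually happened

    def observed(self) -> int:
        """The only outcome a real study ever gets to see."""
        return self.y_treated if self.treated else self.y_untreated


# Hypothetical patients. In reality one of the two outcome columns is always
# missing; we can fill in both only because we made the numbers up.
patients = [
    Patient("A", y_untreated=0, y_treated=1, treated=True),
    Patient("B", y_untreated=1, y_treated=1, treated=False),
    Patient("C", y_untreated=0, y_treated=0, treated=True),
    Patient("D", y_untreated=1, y_treated=0, treated=False),
]

# Knowing both columns, individual causal effects are simple differences.
true_effects = {p.name: p.y_treated - p.y_untreated for p in patients}
true_ate = mean(true_effects.values())

# With only the observed column, comparing groups gives a different answer.
observed_rows = [(p.name, p.treated, p.observed()) for p in patients]
naive_difference = (mean(p.observed() for p in patients if p.treated)
                    - mean(p.observed() for p in patients if not p.treated))

print("true average effect:", true_ate)
print("naive treated-minus-untreated difference:", naive_difference)
```

Here the true average effect is 0, yet the naive comparison of recorded outcomes suggests the treatment lowers recovery by 0.5, because the patients who happened to be treated differ from those who were not.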
Conditioning and intervening are both ways to control for confounding variables. Conditioning on a variable means narrowing the focus to subpopulations that share the same value of that variable, so that the confounder is represented equally in the groups being compared. For example, if we’re trying to determine the effect of social media use on academic performance, the high school a student attends is a possible confounding variable, because the quality of education and academic standards may differ between schools. By conditioning, we might look at only one high school; or, if we’re looking at several high schools, we would ensure that each school is equally represented in both the control and test groups. Because it doesn’t require changing anything about the individuals in the study, conditioning is a practical technique for observational studies.
Intervening on a variable, by contrast, means fixing its value so that the entire population shares the same value of that confounding variable. Unlike conditioning, this means actually changing an aspect of the individuals in the population. In the social media example, intervening would mean forcing all of the students in the study to switch high schools so that they all attend the same one. Although interventions can identify causal relationships more reliably than conditioning, they are not feasible in observational studies and can raise ethical concerns in a clinical setting.
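Here is a minimal sketch of conditioning on the school, assuming entirely made-up data: three schools with different average grades, more heavy social media users at the weaker schools, and a simulated effect of heavy use of exactly minus 5 grade points. The naive comparison mixes school quality into the estimate; comparing heavy and light users within each school and averaging those differences by school size recovers the simulated effect.

```python
import numpy as np

# Hypothetical data: three high schools with different academic standards.
# Weaker schools have more heavy social media users, and heavy use lowers
# grades by exactly 5 points in the simulation (the "true" causal effect).
rng = np.random.default_rng(1)
n = 30_000
school_base = np.array([60.0, 70.0, 80.0])   # average grade at each school
heavy_use_rate = np.array([0.8, 0.5, 0.2])   # share of heavy users at each school

school = rng.integers(0, 3, size=n)
heavy = rng.random(n) < heavy_use_rate[school]
grade = school_base[school] - 5.0 * heavy + rng.normal(0.0, 10.0, size=n)

# Naive comparison: mixes the effect of social media with the effect of school.
naive = grade[heavy].mean() - grade[~heavy].mean()

# Conditioning: compare within each school, then average the per-school
# differences weighted by how common each school is:
#   sum over h of (E[G | S=1, H=h] - E[G | S=0, H=h]) * P(H=h)
adjusted = 0.0
for h in range(3):
    in_h = school == h
    diff_h = grade[in_h & heavy].mean() - grade[in_h & ~heavy].mean()
    adjusted += diff_h * in_h.mean()

print(f"naive: {naive:.2f}, adjusted: {adjusted:.2f}, simulated truth: -5.00")
```

With these numbers the naive estimate is roughly 13 points, far more negative than the simulated effect, while the within-school average lands close to minus 5.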
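An intervention can be sketched by treating the same hypothetical model as a simulator we are allowed to edit: instead of letting students choose schools, we overwrite that step and place everyone in one school, then regenerate everything downstream. The `simulate` function and its `do_school` parameter are illustrative names, and all numbers are invented.

```python
import numpy as np

# The same hypothetical school / social media / grades model as a simulator.
SCHOOL_BASE = np.array([60.0, 70.0, 80.0])   # average grade at each school
HEAVY_USE_RATE = np.array([0.8, 0.5, 0.2])   # share of heavy users at each school


def simulate(n, rng, do_school=None):
    """Generate students; do_school=h forces every student into school h."""
    if do_school is None:
        school = rng.integers(0, 3, size=n)   # students end up in schools naturally
    else:
        school = np.full(n, do_school)        # intervention: everyone at school h
    heavy = rng.random(n) < HEAVY_USE_RATE[school]
    grade = SCHOOL_BASE[school] - 5.0 * heavy + rng.normal(0.0, 10.0, size=n)
    return heavy, grade


rng = np.random.default_rng(2)

# Without intervention, a plain comparison is confounded by school.
heavy, grade = simulate(30_000, rng)
naive_natural = grade[heavy].mean() - grade[~heavy].mean()

# Under do(school = 1), school no longer varies, so a plain comparison is fair.
heavy, grade = simulate(30_000, rng, do_school=1)
effect_under_do = grade[heavy].mean() - grade[~heavy].mean()

print(f"plain comparison, natural schools: {naive_natural:.2f}")
print(f"plain comparison, do(school = 1): {effect_under_do:.2f} (simulated truth: -5)")
```

The design choice to pass the intervention as an argument mirrors the idea that an intervention replaces one mechanism in the system while leaving the others untouched.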
To avoid the potential ethical problems and impracticality of controlled clinical trials, do-calculus can be used to compute the effects of interventions from observational data, provided the causal structure of the problem is known. Anyone interested in a more thorough understanding of do-calculus can refer to Judea Pearl’s paper, “The Do-Calculus Revisited”.
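For reference, the three rules of do-calculus are usually stated as follows. A bar over a variable in the graph subscript means all arrows into it are deleted, a bar under it means all arrows out of it are deleted, and Z(W) is the set of Z-nodes that are not ancestors of any W-node in the graph with arrows into X deleted. The last line is the back-door adjustment formula that the rules yield when a set Z satisfies the back-door criterion; it is the same “compare within groups, then average” computation that conditioning performs.

```latex
\text{Rule 1: } P(y \mid do(x), z, w) = P(y \mid do(x), w)
  \quad \text{if } (Y \perp\!\!\!\perp Z \mid X, W)_{G_{\overline{X}}}

\text{Rule 2: } P(y \mid do(x), do(z), w) = P(y \mid do(x), z, w)
  \quad \text{if } (Y \perp\!\!\!\perp Z \mid X, W)_{G_{\overline{X}\,\underline{Z}}}

\text{Rule 3: } P(y \mid do(x), do(z), w) = P(y \mid do(x), w)
  \quad \text{if } (Y \perp\!\!\!\perp Z \mid X, W)_{G_{\overline{X}\,\overline{Z(W)}}}

\text{Back-door adjustment: } P(y \mid do(x)) = \sum_{z} P(y \mid x, z)\, P(z)
```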
The problem with machine learning
In its most common form, machine learning works by exploring data and identifying patterns. Note the word patterns: the model is merely identifying correlations in data. Compare the relationship between a higher murder rate and increased ice cream sales with the relationship between warmer temperatures and increased ice cream sales. You can tell that the first is merely associational, whereas the second is probably causal. A machine learning model sees no difference. It tries to make a prediction, and a prediction asks: “What usually happens?” given a particular set of circumstances. Causal inference, meanwhile, asks: “What would happen if we intervened on the system?”
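The gap between prediction and intervention can be sketched with a tiny made-up world in which temperature drives both ice cream sales and the murder rate, and neither of those affects the other. A regression of sales on murder rate finds a strong positive slope, which is a perfectly good answer to “what usually happens?”. But if we intervene and push the murder rate up, sales do not move, because the mechanism that generates sales never reads the murder rate. All coefficients are invented for illustration.

```python
import numpy as np

# Hypothetical world: temperature drives BOTH ice cream sales and the murder
# rate; neither of those two variables affects the other.
rng = np.random.default_rng(3)
n = 5_000
temperature = rng.normal(20.0, 5.0, size=n)
sales_noise = rng.normal(0.0, 3.0, size=n)
murder_noise = rng.normal(0.0, 2.0, size=n)


def ice_cream_sales(temp):
    """Structural equation for sales: a function of temperature only."""
    return 2.0 * temp + sales_noise


murder_rate = 0.5 * temperature + murder_noise
sales = ice_cream_sales(temperature)

# Prediction: fit sales on murder rate. Useful for "what usually happens?"
slope, intercept = np.polyfit(murder_rate, sales, deg=1)
predicted_gain = slope * 5.0   # sales change that usually accompanies +5 murders

# Intervention: force the murder rate up by 5 everywhere. Sales are recomputed
# from their own mechanism, which does not take the murder rate as an input.
sales_after_do = ice_cream_sales(temperature)
true_gain = float((sales_after_do - sales).mean())

print(f"slope learned from data: {slope:.2f}")
print(f"predicted gain for +5 murders: {predicted_gain:.1f}")
print(f"actual gain under the intervention: {true_gain:.1f}")
```

The regression is not wrong; it answers a different question. Only a model of the causal structure can say that the intervention leaves sales unchanged.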
Applications of causal inference in machine learning
If machine learning models could understand causality, they could help with a myriad of problems that currently plague the field, including bias, heavy data requirements, lack of transparency, and overfitting. One major application of causal inference in machine learning is transfer learning, specifically domain adaptation. Domain adaptation is a type of transfer learning in which the task stays the same but the specific application (or domain) changes.
Traditionally, a model is given training data so that it can learn to predict certain outputs from certain inputs. For example, we could build a machine learning model to predict whether a patient has a certain disease based on their symptoms. Suppose that all of the training data comes from the Mayo Clinic, but Toronto General Hospital wants to use the model. The domain then changes from the Mayo Clinic to Toronto General Hospital. Causal inference can be used to carry out this domain adaptation.
There are three main types of shift that should be analyzed during domain adaptation:
- Target shift: estimating the fraction of sick people in the new domain and adjusting the prediction model accordingly.
- Conditional shift: determining how the change in domain affects how patients exhibit symptoms. This is harder to account for and typically requires domain knowledge to understand the underlying causal structure.
- Generalized target shift: analyzing both target shift and conditional shift in the new domain. This is the most complex case and can be handled by combining different methods.
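For the simplest case, target shift, one standard correction follows directly from Bayes’ rule: if the way symptoms look given sick or healthy is the same in both hospitals, the new probability of being sick is proportional to the old model’s probability times the ratio of the new to the old prevalence. The sketch below assumes the Toronto prevalence is already known; in practice it has to be estimated from unlabeled Toronto data, which is the hard part. The function name and all numbers are hypothetical.

```python
def adjust_for_target_shift(p_sick_source: float, prev_source: float, prev_target: float) -> float:
    """Re-weight a classifier's P(sick | symptoms) for a new disease prevalence.

    Assumes pure target shift: P(symptoms | sick) and P(symptoms | healthy) are
    the same in both hospitals; only the share of sick patients changes.
    By Bayes' rule the new posterior is proportional to the old posterior times
    the ratio of new to old class priors.
    """
    w_sick = prev_target / prev_source
    w_healthy = (1.0 - prev_target) / (1.0 - prev_source)
    sick = p_sick_source * w_sick
    healthy = (1.0 - p_sick_source) * w_healthy
    return sick / (sick + healthy)


# Hypothetical numbers: 10% of training patients (Mayo Clinic) were sick, and
# we assume 30% of Toronto General patients are. The model outputs 0.5.
p_toronto = adjust_for_target_shift(0.5, prev_source=0.10, prev_target=0.30)
print(f"{p_toronto:.3f}")
```

A patient the original model considered a coin flip becomes about 79% likely to be sick in the higher-prevalence hospital; when the two prevalences are equal, the function returns the original probability unchanged.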
Domain adaptation also helps address “dataset shift,” a mismatch between the training and test data distributions. For example, many military databases contain healthcare information; however, this data disproportionately over-represents males. Similarly, wealthier individuals tend to be over-represented in healthcare studies. Domain adaptation can help address this problem without collecting large amounts of additional data.
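One simple way to use existing data more carefully is importance reweighting: give each record a weight so that the weighted mix of a recorded attribute (here, sex) matches the population the model will serve. This is a sketch of one common technique, not necessarily what the paper proposes, and it assumes the only difference between the datasets is the group mix. The function name and the counts are hypothetical.

```python
from collections import Counter


def reweight_to_target(groups, target_share):
    """One weight per record so the weighted group mix matches target_share.

    groups: group label of each training record (e.g. "male" / "female").
    target_share: share of each label in the population the model will serve.
    """
    counts = Counter(groups)
    n = len(groups)
    return [target_share[g] / (counts[g] / n) for g in groups]


# Hypothetical training set: 80 male and 20 female records; 8 of the men and
# 6 of the women have the condition. The deployment population is 50/50.
groups = ["male"] * 80 + ["female"] * 20
has_condition = [1] * 8 + [0] * 72 + [1] * 6 + [0] * 14
weights = reweight_to_target(groups, {"male": 0.5, "female": 0.5})

naive_rate = sum(has_condition) / len(has_condition)
weighted_rate = sum(w * y for w, y in zip(weights, has_condition)) / sum(weights)
male_share = sum(w for w, g in zip(weights, groups) if g == "male") / sum(weights)
print(f"naive prevalence: {naive_rate:.2f}, reweighted: {weighted_rate:.2f}")
```

The reweighted estimate (0.20) reflects a 50/50 population, while the raw average (0.14) is pulled toward the over-represented group. Reweighting only fixes imbalance in attributes that were recorded; it cannot correct for differences nobody measured.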
Reinforcement learning with causality
The underlying idea of reinforcement learning is that an artificial agent, given the observed state, chooses among the available actions so as to collect the most reward. The goal is to maximize the overall reward by choosing the best possible action in every interaction with the environment.
The multi-armed bandit problem is a classic reinforcement learning problem. The agent, known as the gambler, sits in front of a row of slot machines and must decide which machines to play in order to maximize their overall return. Initially, however, the gambler has no idea of the probability of winning on each slot machine.
Each slot machine gives either a positive reward with probability P or a negative reward with probability 1-P. Let’s say there are five slot machines with win probabilities of [0.1, 0.2, 0.3, 0.4, 0.5]. If the gambler knew these probabilities, they would always pick the last one; however, they do not have this information.
The greedy approach to this problem is to always take the action that has been most rewarding so far. However, this is not optimal, because an unlucky early result can make the gambler abandon a good machine for good. There are better approaches, most notably Thompson sampling. Unlike many other approaches, Thompson sampling doesn’t simply refine an estimate of the mean reward obtained so far. Instead, it maintains a probability distribution over each machine’s win rate, built from previous rewards, and chooses which machine to play by sampling from these distributions.
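A quick simulation sketch shows how the greedy approach can go wrong on the five machines above. The agent tries each machine once and then always plays the one with the best running average. If the best machine happens to lose on its first pull, its estimate sits at minus 1 and it may never be tried again. The helper names `pull` and `greedy` are illustrative.

```python
import numpy as np

# The five hypothetical machines from the text (index 4 is the best one).
WIN_PROBS = [0.1, 0.2, 0.3, 0.4, 0.5]


def pull(arm, rng):
    """Return +1 with probability WIN_PROBS[arm], otherwise -1."""
    return 1 if rng.random() < WIN_PROBS[arm] else -1


def greedy(n_rounds, rng):
    """Try each machine once, then always play the best running average."""
    k = len(WIN_PROBS)
    totals = np.zeros(k)
    counts = np.zeros(k, dtype=int)
    for t in range(n_rounds):
        if t < k:
            arm = t                                # initial sweep: one pull each
        else:
            arm = int(np.argmax(totals / counts))  # exploit the best average so far
        totals[arm] += pull(arm, rng)
        counts[arm] += 1
    return counts


# How often does greedy end up mostly playing the best machine (index 4)?
runs = 200
hits = sum(
    int(np.argmax(greedy(1_000, np.random.default_rng(seed)))) == 4
    for seed in range(runs)
)
print(f"greedy favored the best machine in {hits} of {runs} runs")
```

Across many random seeds, greedy regularly settles on a worse machine, which is exactly the failure Thompson sampling is designed to avoid.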
If the available rewards are binary (win or lose, as they are in the most basic multi-armed bandit scenario), a Beta distribution can be used as the probability distribution. In this distribution, α tracks the number of successes and β the number of failures. The distribution has a mean value of:
$$\mu = \frac{\alpha}{\alpha + \beta}$$
A Beta distribution can be created for each of the slot machines. Before the gambler begins playing a machine, α and β can both be set to one, generating the flat line seen in red on the diagram; from then on, α is one plus the number of wins on that machine and β is one plus the number of losses. As the gambler plays a machine more times, the standard deviation of its distribution decreases and the gambler gets a better estimate of that machine’s win probability. However, there’s a major issue: the distribution could be biased by unobserved confounders. This can be addressed with causal Thompson sampling.
Causal Thompson sampling
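The procedure described above can be sketched in a few lines, assuming the five hypothetical machines from earlier. Each machine starts with a flat Beta(1, 1); every round the agent draws one plausible win rate per machine, plays the machine with the highest draw, and adds one to α on a win or to β on a loss. The posterior means and standard deviations are printed at the end so you can see the spread shrink for machines that are played often.

```python
import numpy as np

# Hypothetical five machines from the text; index 4 (p = 0.5) is the best.
WIN_PROBS = [0.1, 0.2, 0.3, 0.4, 0.5]


def thompson(n_rounds, rng):
    k = len(WIN_PROBS)
    alpha = np.ones(k)   # Beta(1, 1) prior: the flat starting distribution
    beta = np.ones(k)
    counts = np.zeros(k, dtype=int)
    for _ in range(n_rounds):
        # Draw one plausible win rate per machine from its current Beta
        # distribution and play the machine with the highest draw.
        arm = int(np.argmax(rng.beta(alpha, beta)))
        if rng.random() < WIN_PROBS[arm]:  # positive reward (a win)
            alpha[arm] += 1
        else:                              # negative reward (a loss)
            beta[arm] += 1
        counts[arm] += 1
    return alpha, beta, counts


alpha, beta, counts = thompson(5_000, np.random.default_rng(0))
post_mean = alpha / (alpha + beta)
# Standard deviation of Beta(a, b); it shrinks as a machine is played more.
post_std = np.sqrt(alpha * beta / ((alpha + beta) ** 2 * (alpha + beta + 1)))
print("plays per machine:", counts)
print("posterior means:  ", post_mean.round(3))
print("posterior stds:   ", post_std.round(3))
```

Because weak machines still get an occasional optimistic draw, the agent keeps exploring a little, but after a few thousand rounds nearly all plays go to the best machine and its posterior mean sits close to 0.5.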
Counterfactual reasoning in Thompson sampling could help prevent unobserved confounders from causing the agent to pick a suboptimal machine. For example, suppose a gambler is sitting in front of two slot machines and has a Beta distribution for each. Let’s consider two additional confounding variables: (1) the gambler may or may not be drunk, and (2) the slot machines may have more or fewer flashing lights. Perhaps, when the gambler is drunk, he has a subconscious predilection for the flashing slot machine.
In this case, the gambler should determine both his intuition and his counter-intuition. The intuition is found by asking: “Given that I believe the flashing slot machine is better, what would the payout be if I played that one?” Likewise, for the counter-intuition, the gambler can ask: “Given that I believe the flashing slot machine is better, what would the payout be if I acted differently?” By comparing the estimated rewards of following versus overriding his intuition, the gambler gains an understanding of the causal effect of his choices and can potentially correct his biased beliefs. This kind of causal Thompson sampling works on toy examples (see here if interested), but how to handle unobserved confounders in reinforcement learning for real-life problems is still an open research question.
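Below is a simplified sketch of the intuition versus counter-intuition idea, not the published causal Thompson sampling algorithm, which is more involved. It assumes a hypothetical two-machine casino with invented payouts in which being drunk and the flashing lights shape both the gambler’s impulse and the payouts, and the machine the gambler is drawn to happens to pay worse. Plain Thompson sampling ignores the impulse; the intuition-aware version keeps separate Beta distributions for each impulse, so it learns the payout of each machine given what the gambler felt like playing, which is exactly the pair of questions above.

```python
import numpy as np


def play_round(rng):
    """One visit to a hypothetical two-machine casino.

    Unobserved confounders: whether the gambler is drunk, and which machine
    (0 or 1) has the flashing lights. Both shape the gambler's impulse and the
    payouts. All numbers are made up for illustration.
    """
    drunk = rng.random() < 0.5
    flashing = int(rng.integers(0, 2))
    intuition = flashing if drunk else 1 - flashing  # drunk: drawn to the flashing machine

    def payout(arm):
        if arm == intuition:
            p_win = 0.2                              # following the impulse pays poorly
        else:
            p_win = 0.7 if drunk else 0.6            # overriding it pays better
        return 1 if rng.random() < p_win else -1

    return intuition, payout


def run(n_rounds, rng, use_intuition):
    """Thompson sampling; optionally keep separate Beta posteriors per intuition."""
    alpha = np.ones((2, 2))  # rows: intuition context, columns: machine
    beta = np.ones((2, 2))
    wins = 0
    for _ in range(n_rounds):
        intuition, payout = play_round(rng)
        ctx = intuition if use_intuition else 0      # plain TS ignores the impulse
        arm = int(np.argmax(rng.beta(alpha[ctx], beta[ctx])))
        if payout(arm) > 0:
            alpha[ctx, arm] += 1
            wins += 1
        else:
            beta[ctx, arm] += 1
    return wins / n_rounds


plain = run(5_000, np.random.default_rng(0), use_intuition=False)
aware = run(5_000, np.random.default_rng(0), use_intuition=True)
print(f"win rate, plain Thompson sampling:  {plain:.3f}")
print(f"win rate, intuition-aware sampling: {aware:.3f}")
```

In this made-up casino both machines look equally good on average (a win rate of about 0.425), so plain Thompson sampling cannot do better than that. The intuition-aware agent learns to override its impulse and approaches a win rate of about 0.65. Real problems are far less tidy, which is why this remains an open research question.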