<!doctype html><html lang=en-uk><head><script data-goatcounter=https://ruivieira-dev.goatcounter.com/count async src=//gc.zgo.at/count.js></script><script src=https://unpkg.com/@alpinejs/intersect@3.x.x/dist/cdn.min.js></script><script src=https://unpkg.com/alpinejs@3.x.x/dist/cdn.min.js></script><script type=module src=https://ruivieira.dev/js/deeplinks/deeplinks.js></script><link rel=preload href=https://ruivieira.dev/lib/fonts/fa-brands-400.woff2 as=font type=font/woff2 crossorigin=anonymous><link rel=preload href=https://ruivieira.dev/lib/fonts/fa-regular-400.woff2 as=font type=font/woff2 crossorigin=anonymous><link rel=preload href=https://ruivieira.dev/lib/fonts/fa-solid-900.woff2 as=font type=font/woff2 crossorigin=anonymous><link rel=preload href=https://ruivieira.dev/fonts/firacode/FiraCode-Regular.woff2 as=font type=font/woff2 crossorigin=anonymous><link rel=preload href=https://ruivieira.dev/fonts/vollkorn/Vollkorn-Regular.woff2 as=font type=font/woff2 crossorigin=anonymous><link rel=stylesheet href=https://ruivieira.dev/css/kbd.css type=text/css><meta charset=utf-8><meta http-equiv=X-UA-Compatible content="IE=edge"><title>Counterfactual Fairness in Java · Rui Vieira</title>
<link rel=canonical href=https://ruivieira.dev/counterfactual-fairness-in-java.html><meta name=viewport content="width=device-width,initial-scale=1"><meta name=robots content="all,follow"><meta name=googlebot content="index,follow,snippet,archive"><meta property="og:title" content="Counterfactual Fairness in Java"><meta property="og:description" content="Here we will look at how to build a counterfactually fair model, as detailed in Counterfactual Fairness, specifically the “Fair Add” model.
This implementation will rely mostly on Apache Commons Math linear regression implementations, namely the Ordinary Least Squares (OLS) regression. We start then by adding the relevant Maven dependencies:
<dependency> <groupId>org.apache.commons</groupId> <artifactId>commons-math3</artifactId> <version>3.6.1</version> </dependency> Data will be passed as a RealMatrix. This matrix will have dimensions $N\times f$, where $N$ is the number of observations and $f$ is the number of features."><meta property="og:type" content="article"><meta property="og:url" content="https://ruivieira.dev/counterfactual-fairness-in-java.html"><meta property="article:section" content="posts"><meta property="article:modified_time" content="2023-05-28T11:45:56+01:00"><meta name=twitter:card content="summary"><meta name=twitter:title content="Counterfactual Fairness in Java"><meta name=twitter:description content="Here we will look at how to build a counterfactually fair model, as detailed in Counterfactual Fairness, specifically the “Fair Add” model.
This implementation will rely mostly on Apache Commons Math linear regression implementations, namely the Ordinary Least Squares (OLS) regression. We start then by adding the relevant Maven dependencies:
<dependency> <groupId>org.apache.commons</groupId> <artifactId>commons-math3</artifactId> <version>3.6.1</version> </dependency> Data will be passed as a RealMatrix. This matrix will have dimensions $N\times f$, where $N$ is the number of observations and $f$ is the number of features."><link rel=stylesheet href=https://ruivieira.dev/css/styles.css><!--[if lt IE 9]><script src=https://oss.maxcdn.com/html5shiv/3.7.2/html5shiv.min.js></script><script src=https://oss.maxcdn.com/respond/1.4.2/respond.min.js></script><![endif]--><link rel=icon type=image/png href=https://ruivieira.dev/images/favicon.ico></head><body class="max-width mx-auto px3 ltr" x-data="{currentHeading: undefined}"><div class="content index py4"><div id=header-post><a id=menu-icon href=#><i class="fas fa-eye fa-lg"></i></a>
<a id=menu-icon-tablet href=#><i class="fas fa-eye fa-lg"></i></a>
<a id=top-icon-tablet href=# onclick='$("html, body").animate({scrollTop:0},"fast")' style=display:none aria-label="Top of Page"><i class="fas fa-chevron-up fa-lg"></i></a>
<span id=menu><span id=nav><ul><li><a href=https://ruivieira.dev/>Home</a></li><li><a href=https://ruivieira.dev/blog/>Blog</a></li><li><a href=https://ruivieira.dev/draw/>Drawings</a></li><li><a href=https://ruivieira.dev/map/>All pages</a></li><li><a href=https://ruivieira.dev/search.html>Search</a></li></ul></span><br><div id=share style=display:none></div><div id=toc><h4>Contents</h4><nav id=TableOfContents><ul></ul></nav><h4>Related</h4><nav><ul><li class="header-post toc"><span class=backlink-count>1</span>
<a href=https://ruivieira.dev/fairness-in-machine-learning.html>Fairness in Machine Learning</a></li></ul></nav></div></span></div><article class=post itemscope itemtype=http://schema.org/BlogPosting><header><h1 class=posttitle itemprop="name headline">Counterfactual Fairness in Java</h1><div class=meta><div class=postdate>Updated <time datetime="2023-05-28 11:45:56 +0100 BST" itemprop=datePublished>2023-05-28</time>
<span class=commit-hash>(<a href=https://ruivieira.dev/log/index.html#1454e0a>1454e0a</a>)</span></div></div></header><div class=content itemprop=articleBody><p>Here we look at how to build a counterfactually fair model, as detailed in <a href=https://ruivieira.dev/counterfactual-fairness.html>Counterfactual Fairness</a>, specifically the “Fair Add” model.</p><p>This implementation relies mostly on the Apache Commons Math<sup id=fnref:1><a href=#fn:1 class=footnote-ref role=doc-noteref>1</a></sup> linear regression implementations, namely Ordinary Least Squares (OLS) regression<sup id=fnref:2><a href=#fn:2 class=footnote-ref role=doc-noteref>2</a></sup>. We start by adding the relevant Maven dependency:</p><div class=highlight><pre tabindex=0 style=background-color:#fff;-moz-tab-size:4;-o-tab-size:4;tab-size:4><code class=language-xml data-lang=xml><span style=display:flex><span><span style=color:navy>&lt;dependency&gt;</span>
</span></span><span style=display:flex><span> <span style=color:navy>&lt;groupId&gt;</span>org.apache.commons<span style=color:navy>&lt;/groupId&gt;</span>
</span></span><span style=display:flex><span> <span style=color:navy>&lt;artifactId&gt;</span>commons-math3<span style=color:navy>&lt;/artifactId&gt;</span>
</span></span><span style=display:flex><span> <span style=color:navy>&lt;version&gt;</span>3.6.1<span style=color:navy>&lt;/version&gt;</span>
</span></span><span style=display:flex><span><span style=color:navy>&lt;/dependency&gt;</span>
</span></span></code></pre></div><p>Data will be passed as a <code>RealMatrix</code><sup id=fnref:3><a href=#fn:3 class=footnote-ref role=doc-noteref>3</a></sup>. This matrix will have dimensions $N\times f$, where $N$ is the number of observations and $f$ is the number of features.</p><p>We can instantiate the model using:</p><div class=highlight><pre tabindex=0 style=background-color:#fff;-moz-tab-size:4;-o-tab-size:4;tab-size:4><code class=language-java data-lang=java><span style=display:flex><span><span style=color:#998;font-style:italic>// RealMatrix data = ...</span><span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb></span><span style=font-weight:700>final</span><span style=color:#bbb> </span>CounterfactuallyFairModel<span style=color:#bbb> </span>model<span style=color:#bbb> </span><span style=font-weight:700>=</span><span style=color:#bbb> </span><span style=font-weight:700>new</span><span style=color:#bbb> </span>CounterfactuallyFairModel(data);<span style=color:#bbb>
</span></span></span></code></pre></div><p>We will then need the following information:</p><ul><li>The protected attribute indices, <code>protectedIndices</code></li><li>The variable indices, <code>variableIndices</code></li><li>The target variable index, <code>targetIndex</code></li></ul><p>Assuming that we have the same variables as in the <a href=https://ruivieira.dev/counterfactual-fairness.html>counterfactual fairness example</a>, let’s say that the protected attributes occupy columns <code>5, 6, 7, 8, 9, 10, 11, 12, 13, 14</code> of the <code>data</code> matrix, the model variables (<code>LSAT</code> and <code>UGPA</code>) have indices <code>1, 0</code>, and the target (<code>ZFYA</code>) has index <code>2</code>. We then calculate the counterfactually fair model using:</p><div class=highlight><pre tabindex=0 style=background-color:#fff;-moz-tab-size:4;-o-tab-size:4;tab-size:4><code class=language-java data-lang=java><span style=display:flex><span>model.<span style=color:teal>calculate</span>(<span style=font-weight:700>new</span><span style=color:#bbb> </span><span style=color:#458;font-weight:700>int</span><span style=font-weight:700>[]</span>{5,<span style=color:#bbb> </span>6,<span style=color:#bbb> </span>7,<span style=color:#bbb> </span>8,<span style=color:#bbb> </span>9,<span style=color:#bbb> </span>10,<span style=color:#bbb> </span>11,<span style=color:#bbb> </span>12,<span style=color:#bbb> </span>13,<span style=color:#bbb> </span>14},<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>new</span><span style=color:#bbb> </span><span style=color:#458;font-weight:700>int</span><span style=font-weight:700>[]</span>{1,<span style=color:#bbb> </span>0},<span style=color:#bbb> </span>2);<span style=color:#bbb>
</span></span></span></code></pre></div><p>The <code>calculate</code> method performs the following:</p><div class=highlight><pre tabindex=0 style=background-color:#fff;-moz-tab-size:4;-o-tab-size:4;tab-size:4><code class=language-java data-lang=java><span style=display:flex><span><span style=font-weight:700>public</span><span style=color:#bbb> </span><span style=color:#458;font-weight:700>void</span><span style=color:#bbb> </span><span style=color:#900;font-weight:700>calculate</span>(<span style=color:#458;font-weight:700>int</span><span style=font-weight:700>[]</span><span style=color:#bbb> </span>protectedIndices,<span style=color:#bbb> </span><span style=color:#458;font-weight:700>int</span><span style=font-weight:700>[]</span><span style=color:#bbb> </span>variableIndices,<span style=color:#bbb> </span><span style=color:#458;font-weight:700>int</span><span style=color:#bbb> </span>targetIndex)<span style=color:#bbb> </span>{<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>final</span><span style=color:#bbb> </span>RealMatrix<span style=color:#bbb> </span>residuals<span style=color:#bbb> </span><span style=font-weight:700>=</span><span style=color:#bbb> </span><span style=font-weight:700>new</span><span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>Array2DRowRealMatrix(<span style=font-weight:700>this</span>.<span style=color:teal>data</span>.<span style=color:teal>getRowDimension</span>(),<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>variableIndices.<span style=color:teal>length</span>);<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>for</span><span style=color:#bbb> </span>(<span style=color:#458;font-weight:700>int</span><span style=color:#bbb> </span>i<span style=color:#bbb> </span><span style=font-weight:700>=</span><span style=color:#bbb> </span>0;<span style=color:#bbb> </span>i<span style=color:#bbb> </span><span style=font-weight:700>&lt;</span><span style=color:#bbb> </span>variableIndices.<span style=color:teal>length</span>;<span style=color:#bbb> </span>i<span style=font-weight:700>++</span>)<span style=color:#bbb> </span>{<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>final</span><span style=color:#bbb> </span><span style=color:#458;font-weight:700>int</span><span style=color:#bbb> </span>index<span style=color:#bbb> </span><span style=font-weight:700>=</span><span style=color:#bbb> </span>variableIndices<span style=font-weight:700>[</span>i<span style=font-weight:700>]</span>;<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>final</span><span style=color:#bbb> </span>RealVector<span style=color:#bbb> </span>varResidual<span style=color:#bbb> </span><span style=font-weight:700>=</span><span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>this</span>.<span style=color:teal>calculateEpsilon</span>(protectedIndices,<span style=color:#bbb> </span>index);<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>residuals.<span style=color:teal>setColumn</span>(i,<span style=color:#bbb> </span>varResidual.<span style=color:teal>toArray</span>());<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>}<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=color:#998;font-style:italic>// predict target from residuals </span><span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>final</span><span style=color:#bbb> </span>OLSMultipleLinearRegression<span style=color:#bbb> </span>regression<span style=color:#bbb> </span><span style=font-weight:700>=</span><span style=color:#bbb> </span><span style=font-weight:700>new</span><span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>OLSMultipleLinearRegression();<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>regression.<span style=color:teal>newSampleData</span>(<span style=font-weight:700>this</span>.<span style=color:teal>data</span>.<span style=color:teal>getColumn</span>(targetIndex),<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>residuals.<span style=color:teal>getData</span>());<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb></span>}<span style=color:#bbb>
</span></span></span></code></pre></div><p>As in <a href=https://ruivieira.dev/counterfactual-fairness.html>Counterfactual Fairness</a>, we fit a regression model to predict each of the variables (<code>LSAT</code> and <code>UGPA</code>) from the protected attributes. The resulting residuals, $\epsilon_{LSAT}$ and $\epsilon_{UGPA}$, are in turn used to fit a second regression model that predicts the target variable <code>ZFYA</code>.</p><p>The residuals are calculated using the <code>calculateEpsilon</code> method, which consists of:</p><div class=highlight><pre tabindex=0 style=background-color:#fff;-moz-tab-size:4;-o-tab-size:4;tab-size:4><code class=language-java data-lang=java><span style=display:flex><span><span style=font-weight:700>public</span><span style=color:#bbb> </span>RealVector<span style=color:#bbb> </span><span style=color:#900;font-weight:700>calculateEpsilon</span>(<span style=color:#458;font-weight:700>int</span><span style=font-weight:700>[]</span><span style=color:#bbb> </span>protectedIndices,<span style=color:#bbb> </span><span style=color:#458;font-weight:700>int</span><span style=color:#bbb> </span>targetIndex)<span style=color:#bbb> </span>{<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=color:#458;font-weight:700>int</span><span style=font-weight:700>[]</span><span style=color:#bbb> </span>protectedRows<span style=color:#bbb> </span><span style=font-weight:700>=</span><span style=color:#bbb> </span><span style=font-weight:700>new</span><span style=color:#bbb> </span><span style=color:#458;font-weight:700>int</span><span style=font-weight:700>[</span><span style=font-weight:700>this</span>.<span style=color:teal>data</span>.<span style=color:teal>getRowDimension</span>()<span style=font-weight:700>]</span>;<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>Arrays.<span style=color:teal>setAll</span>(protectedRows,<span style=color:#bbb> </span>i<span style=color:#bbb> </span><span style=font-weight:700>-></span><span style=color:#bbb> </span>i);<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>final</span><span style=color:#bbb> </span>RealMatrix<span style=color:#bbb> </span>_x<span style=color:#bbb> </span><span style=font-weight:700>=</span><span style=color:#bbb> </span><span style=font-weight:700>this</span>.<span style=color:teal>data</span>.<span style=color:teal>getSubMatrix</span>(protectedRows,<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>protectedIndices);<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>final</span><span style=color:#bbb> </span>RealVector<span style=color:#bbb> </span>_y<span style=color:#bbb> </span><span style=font-weight:700>=</span><span style=color:#bbb> </span><span style=font-weight:700>this</span>.<span style=color:teal>data</span>.<span style=color:teal>getSubMatrix</span>(protectedRows,<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>new</span><span style=color:#bbb> </span><span style=color:#458;font-weight:700>int</span><span style=font-weight:700>[]</span>{targetIndex}).<span style=color:teal>getColumnVector</span>(0);<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>final</span><span style=color:#bbb> </span>OLSMultipleLinearRegression<span style=color:#bbb> </span>regression<span style=color:#bbb> </span><span style=font-weight:700>=</span><span style=color:#bbb> </span><span style=font-weight:700>new</span><span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>OLSMultipleLinearRegression();<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span>regression.<span style=color:teal>newSampleData</span>(_y.<span style=color:teal>toArray</span>(),<span style=color:#bbb> </span>_x.<span style=color:teal>getData</span>());<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb> </span><span style=font-weight:700>return</span><span style=color:#bbb> </span><span style=font-weight:700>new</span><span style=color:#bbb> </span>ArrayRealVector(regression.<span style=color:teal>estimateResiduals</span>());<span style=color:#bbb>
</span></span></span><span style=display:flex><span><span style=color:#bbb></span>}<span style=color:#bbb>
</span></span></span></code></pre></div><p>This method simply fits a regression model for the given variable using the protected attributes and returns a <code>RealVector</code> with the residuals $\epsilon$.</p><div class=footnotes role=doc-endnotes><hr><ol><li id=fn:1><p><a href=https://commons.apache.org/proper/commons-math/>https://commons.apache.org/proper/commons-math/</a> <a href=#fnref:1 class=footnote-backref role=doc-backlink>↩︎</a></p></li><li id=fn:2><p><a href=https://commons.apache.org/proper/commons-math/javadocs/api-3.6/org/apache/commons/math3/stat/regression/OLSMultipleLinearRegression.html>https://commons.apache.org/proper/commons-math/javadocs/api-3.6/org/apache/commons/math3/stat/regression/OLSMultipleLinearRegression.html</a> <a href=#fnref:2 class=footnote-backref role=doc-backlink>↩︎</a></p></li><li id=fn:3><p><a href=https://commons.apache.org/proper/commons-math/javadocs/api-3.6/org/apache/commons/math3/linear/RealMatrix.html>https://commons.apache.org/proper/commons-math/javadocs/api-3.6/org/apache/commons/math3/linear/RealMatrix.html</a> <a href=#fnref:3 class=footnote-backref role=doc-backlink>↩︎</a></p></li></ol></div></div></article><div id=footer-post-container><div id=footer-post><div id=nav-footer style=display:none><ul><li><a href=https://ruivieira.dev/>Home</a></li><li><a href=https://ruivieira.dev/blog/>Blog</a></li><li><a href=https://ruivieira.dev/draw/>Drawings</a></li><li><a href=https://ruivieira.dev/map/>All pages</a></li><li><a href=https://ruivieira.dev/search.html>Search</a></li></ul></div><div id=toc-footer style=display:none><nav id=TableOfContents></nav></div><div id=share-footer style=display:none></div><div id=actions-footer><a id=menu-toggle class=icon href=# onclick='return $("#nav-footer").toggle(),!1' aria-label=Menu><i class="fas fa-bars fa-lg" aria-hidden=true></i> Menu</a>
<a id=toc-toggle class=icon href=# onclick='return $("#toc-footer").toggle(),!1' aria-label=TOC><i class="fas fa-list fa-lg" aria-hidden=true></i> TOC</a>
<a id=share-toggle class=icon href=# onclick='return $("#share-footer").toggle(),!1' aria-label=Share><i class="fas fa-share-alt fa-lg" aria-hidden=true></i> share</a>
<a id=top style=display:none class=icon href=# onclick='$("html, body").animate({scrollTop:0},"fast")' aria-label="Top of Page"><i class="fas fa-chevron-up fa-lg" aria-hidden=true></i> Top</a></div></div></div><footer id=footer><div class=footer-left>Copyright © 2024 Rui Vieira</div><div class=footer-right><nav><ul><li><a href=https://ruivieira.dev/>Home</a></li><li><a href=https://ruivieira.dev/blog/>Blog</a></li><li><a href=https://ruivieira.dev/draw/>Drawings</a></li><li><a href=https://ruivieira.dev/map/>All pages</a></li><li><a href=https://ruivieira.dev/search.html>Search</a></li></ul></nav></div></footer></div></body><link rel=stylesheet href=https://ruivieira.dev/css/fa.min.css><script src=https://ruivieira.dev/js/jquery-3.6.0.min.js></script><script src=https://ruivieira.dev/js/mark.min.js></script><script src=https://ruivieira.dev/js/main.js></script><script>MathJax={tex:{inlineMath:[["$","$"],["\\(","\\)"]]},svg:{fontCache:"global"}}</script><script type=text/javascript id=MathJax-script async src=https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js></script></html>