Logistic Regression
Introduction
In classification problems, the goal is to predict class membership from a set of predictors. Often there are two classes, and one of the most popular methods for binary classification is logistic regression (Freedman 2009).
Suppose that \(y_i \in \{0,1\}\) is a binary class indicator. The conditional response is modeled as \(y|x \sim \mbox{Bernoulli}(g_{\beta}(x))\), where \(g_{\beta}(x) = \frac{1}{1 + e^{-x^T\beta}}\) is the logistic function. Maximizing the log-likelihood yields the optimization problem
\[ \begin{array}{ll} \underset{\beta}{\mbox{maximize}} & \sum_{i=1}^m \{ y_i\log(g_{\beta}(x_i)) + (1-y_i)\log(1 - g_{\beta}(x_i)) \}. \end{array} \]
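Writing \(f(z) = \log(1 + e^z)\), the two log-probability terms simplify, since \(\log(g_{\beta}(x)) = -f(-x^T\beta)\) and \(\log(1 - g_{\beta}(x)) = -f(x^T\beta)\), so the objective can be rewritten as

\[ \sum_{i=1}^m \{ y_i\log(g_{\beta}(x_i)) + (1-y_i)\log(1 - g_{\beta}(x_i)) \} = -\sum_{i=1}^m \{ y_i f(-x_i^T\beta) + (1-y_i) f(x_i^T\beta) \}. \]

Since \(f\) is convex, the right-hand side is concave in \(\beta\) and suitable for maximization; this is the form used in the code below.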
CVXR provides the logistic atom as a shortcut for \(f(z) = \log(1 + e^z)\), which lets us express this optimization problem directly. One may be tempted to write log(1 + exp(X %*% beta)) as in conventional R syntax. However, this representation of \(f(z)\) violates the DCP composition rules, so the CVXR parser will reject the problem even though the objective is convex. Users who wish to employ a function that is convex but not DCP-compliant should check the documentation for a suitable atom or consider a different formulation.
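To illustrate the point, here is a minimal sketch (assuming a current CVXR installation) that queries the curvature analysis directly with is_convex(); the recognized logistic atom is certified convex, while the naive composition is not:

```r
library(CVXR)

X <- matrix(stats::rnorm(10), nrow = 5, ncol = 2)
beta <- Variable(2)

## logistic() is a recognized convex atom, so DCP analysis succeeds
expr_ok <- sum(logistic(X %*% beta))

## log() of a convex argument has unknown curvature under the DCP rules,
## so this equivalent expression is not certified convex
expr_bad <- sum(log(1 + exp(X %*% beta)))

is_convex(expr_ok)
is_convex(expr_bad)
```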
Example
The formulation is very similar to OLS, except for the specification of the objective. The example below also demonstrates a key feature of CVXR: evaluating functions of the variables after solving the optimization problem. For instance, the log-odds \(X\hat{\beta}\), where \(\hat{\beta}\) is the logistic regression estimate, is simply specified as X %*% beta below, and the getValue function of the result computes its value. (Any other function of the estimate can be computed similarly.)
library(CVXR)

## Problem data: a sparse true coefficient vector and noisy labels in {-1, 1}
n <- 20
m <- 1000
offset <- 0
sigma <- 45
DENSITY <- 0.2
set.seed(183991)
beta_true <- stats::rnorm(n)
idxs <- sample(n, size = floor((1 - DENSITY) * n), replace = FALSE)
beta_true[idxs] <- 0
X <- matrix(stats::rnorm(m * n, 0, 5), nrow = m, ncol = n)
y <- sign(X %*% beta_true + offset + stats::rnorm(m, 0, sigma))

## Maximize the log-likelihood, expressed with the logistic atom
beta <- Variable(n)
obj <- -sum(logistic(-X[y <= 0, ] %*% beta)) - sum(logistic(X[y == 1, ] %*% beta))
prob <- Problem(Maximize(obj))
result <- solve(prob)

## Evaluate functions of the solution directly via getValue
log_odds <- result$getValue(X %*% beta)
beta_res <- result$getValue(beta)
y_probs <- 1 / (1 + exp(-X %*% beta_res))
We can compare with the standard stats::glm estimate.
library(kableExtra)  # for kable_styling() and column_spec(); re-exports %>%

d <- data.frame(y = as.numeric(y > 0), X = X)
glm <- stats::glm(formula = y ~ 0 + X, family = "binomial", data = d)
est.table <- data.frame("CVXR.est" = beta_res, "GLM.est" = coef(glm))
rownames(est.table) <- paste0("$\\beta_{", 1:n, "}$")
knitr::kable(est.table, format = "html") %>%
    kable_styling("striped") %>%
    column_spec(1:3, background = "#ececec")
| | CVXR.est | GLM.est |
|---|---|---|
| \(\beta_{1}\) | -0.0305494 | 0.0305494 |
| \(\beta_{2}\) | 0.0023528 | -0.0023528 |
| \(\beta_{3}\) | -0.0110080 | 0.0110080 |
| \(\beta_{4}\) | 0.0163919 | -0.0163919 |
| \(\beta_{5}\) | 0.0157186 | -0.0157186 |
| \(\beta_{6}\) | 0.0006251 | -0.0006251 |
| \(\beta_{7}\) | -0.0157914 | 0.0157914 |
| \(\beta_{8}\) | -0.0092228 | 0.0092228 |
| \(\beta_{9}\) | 0.0173823 | -0.0173823 |
| \(\beta_{10}\) | 0.0019102 | -0.0019102 |
| \(\beta_{11}\) | -0.0100746 | 0.0100746 |
| \(\beta_{12}\) | -0.0269883 | 0.0269883 |
| \(\beta_{13}\) | 0.0233625 | -0.0233625 |
| \(\beta_{14}\) | 0.0009529 | -0.0009529 |
| \(\beta_{15}\) | -0.0016264 | 0.0016264 |
| \(\beta_{16}\) | 0.0312156 | -0.0312156 |
| \(\beta_{17}\) | 0.0038949 | -0.0038949 |
| \(\beta_{18}\) | -0.0121105 | 0.0121105 |
| \(\beta_{19}\) | 0.0246811 | -0.0246811 |
| \(\beta_{20}\) | -0.0007025 | 0.0007025 |
The sign difference is due to the coding of \(y\) as \((-1, 1)\) for CVXR rather than \((0, 1)\) for stats::glm: the objective above treats the \(y = -1\) class as the "success" class, which is equivalent to flipping the labels and therefore negates every coefficient. So, for completeness, if we code \(y\) as \((0, 1)\), the objective has to be modified as below.
## With y coded as (0, 1): loglik = -sum_{y=0} x_i'beta - sum_i logistic(-x_i'beta)
obj <- -sum(X[y <= 0, ] %*% beta) - sum(logistic(-X %*% beta))
prob <- Problem(Maximize(obj))
result <- solve(prob)
beta_log <- result$getValue(beta)
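This objective follows from the log-likelihood for \(y_i \in \{0,1\}\). Since \(\log(g_{\beta}(x_i)) = -f(-x_i^T\beta)\) and \(\log(1 - g_{\beta}(x_i)) = -f(x_i^T\beta)\) with \(f(z) = \log(1 + e^z)\), the log-likelihood equals \(\sum_{i=1}^m \{ y_i x_i^T\beta - f(x_i^T\beta) \}\); applying the identity \(f(z) = z + f(-z)\) gives

\[ \sum_{i=1}^m \{ y_i x_i^T\beta - f(x_i^T\beta) \} = \sum_{i=1}^m \{ -(1-y_i) x_i^T\beta - f(-x_i^T\beta) \} = -\sum_{i : y_i = 0} x_i^T\beta - \sum_{i=1}^m f(-x_i^T\beta), \]

which is exactly the expression in the code above.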
est.table <- data.frame("CVXR.est" = beta_log, "GLM.est" = coef(glm))
rownames(est.table) <- paste0("$\\beta_{", 1:n, "}$")
knitr::kable(est.table, format = "html") %>%
kable_styling("striped") %>%
column_spec(1:3, background = "#ececec")
| | CVXR.est | GLM.est |
|---|---|---|
| \(\beta_{1}\) | 0.0305494 | 0.0305494 |
| \(\beta_{2}\) | -0.0023528 | -0.0023528 |
| \(\beta_{3}\) | 0.0110080 | 0.0110080 |
| \(\beta_{4}\) | -0.0163919 | -0.0163919 |
| \(\beta_{5}\) | -0.0157186 | -0.0157186 |
| \(\beta_{6}\) | -0.0006251 | -0.0006251 |
| \(\beta_{7}\) | 0.0157914 | 0.0157914 |
| \(\beta_{8}\) | 0.0092228 | 0.0092228 |
| \(\beta_{9}\) | -0.0173823 | -0.0173823 |
| \(\beta_{10}\) | -0.0019102 | -0.0019102 |
| \(\beta_{11}\) | 0.0100746 | 0.0100746 |
| \(\beta_{12}\) | 0.0269883 | 0.0269883 |
| \(\beta_{13}\) | -0.0233625 | -0.0233625 |
| \(\beta_{14}\) | -0.0009529 | -0.0009529 |
| \(\beta_{15}\) | 0.0016264 | 0.0016264 |
| \(\beta_{16}\) | -0.0312156 | -0.0312156 |
| \(\beta_{17}\) | -0.0038949 | -0.0038949 |
| \(\beta_{18}\) | 0.0121105 | 0.0121105 |
| \(\beta_{19}\) | -0.0246811 | -0.0246811 |
| \(\beta_{20}\) | 0.0007025 | 0.0007025 |
Now, the results match perfectly.
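The sign-flip behavior can also be demonstrated with stats::glm alone. The following self-contained sketch (simulated data, independent of the objects above) shows that reversing the class labels negates every coefficient, up to solver tolerance:

```r
set.seed(42)
X <- matrix(stats::rnorm(200), nrow = 100, ncol = 2)
p <- 1 / (1 + exp(-X %*% c(1, -2)))   # true success probabilities
y <- stats::rbinom(100, 1, p)

fit_01 <- stats::glm(y ~ 0 + X, family = "binomial")
fit_10 <- stats::glm(I(1 - y) ~ 0 + X, family = "binomial")  # labels flipped

## The two fits agree up to sign (and IRLS convergence tolerance)
max(abs(coef(fit_01) + coef(fit_10)))
```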
Testthat Results: No output is good
Session Info
sessionInfo()
## R version 4.4.2 (2024-10-31)
## Platform: x86_64-apple-darwin20
## Running under: macOS Sequoia 15.1
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.0
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## time zone: America/Los_Angeles
## tzcode source: internal
##
## attached base packages:
## [1] stats graphics grDevices datasets utils methods base
##
## other attached packages:
## [1] kableExtra_1.4.0 CVXR_1.0-15 testthat_3.2.1.1 here_1.0.1
##
## loaded via a namespace (and not attached):
## [1] gmp_0.7-5 utf8_1.2.4 clarabel_0.9.0.1 sass_0.4.9
## [5] xml2_1.3.6 slam_0.1-54 blogdown_1.19 stringi_1.8.4
## [9] lattice_0.22-6 digest_0.6.37 magrittr_2.0.3 evaluate_1.0.1
## [13] grid_4.4.2 bookdown_0.41 pkgload_1.4.0 fastmap_1.2.0
## [17] rprojroot_2.0.4 jsonlite_1.8.9 Matrix_1.7-1 ECOSolveR_0.5.5
## [21] brio_1.1.5 fansi_1.0.6 Rmosek_10.2.0 viridisLite_0.4.2
## [25] scales_1.3.0 codetools_0.2-20 jquerylib_0.1.4 cli_3.6.3
## [29] Rmpfr_0.9-5 rlang_1.1.4 Rglpk_0.6-5.1 bit64_4.5.2
## [33] munsell_0.5.1 cachem_1.1.0 yaml_2.3.10 tools_4.4.2
## [37] Rcplex_0.3-6 rcbc_0.1.0.9001 colorspace_2.1-1 gurobi_11.0-0
## [41] assertthat_0.2.1 vctrs_0.6.5 R6_2.5.1 lifecycle_1.0.4
## [45] stringr_1.5.1 bit_4.5.0 desc_1.4.3 cccp_0.3-1
## [49] pillar_1.9.0 bslib_0.8.0 glue_1.8.0 Rcpp_1.0.13-1
## [53] systemfonts_1.1.0 highr_0.11 xfun_0.49 rstudioapi_0.17.1
## [57] knitr_1.48 htmltools_0.5.8.1 rmarkdown_2.29 svglite_2.1.3
## [61] compiler_4.4.2