Prototyping the pliable lasso
Introduction
Tibshirani and Friedman (2017) propose a generalization of the lasso that allows the model coefficients to vary as a function of a general set of modifying variables, such as gender, age or time. The pliable lasso model has the form
\[ \begin{equation} \hat{y} = \beta_0{\mathbf 1} + Z\theta_0 + \sum_{j=1}^p(X_j\beta_j + W_j\theta_j) \end{equation} \]
where \(\hat{y}\) is the \(N\times 1\) vector of predictions, \(\beta_0\) is a scalar, \(\theta_0\) is a \(K\)-vector, and \(X\) and \(Z\) are \(N\times p\) and \(N\times K\) matrices containing values of the predictor and modifying variables, respectively. Here \(W_j=X_j \circ Z\) denotes the elementwise multiplication of each column of \(Z\) by \(X_j\), the \(j\)-th column of \(X\).
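As a quick illustration of \(W_j\): in R, the elementwise product of a column with a matrix is just vector recycling. The toy sketch below (with hypothetical Xtoy and Ztoy, separate from the simulation later in this post) computes \(W_1\).

Xtoy <- matrix(rnorm(5 * 3), nrow = 5, ncol = 3)
Ztoy <- matrix(rbinom(5 * 2, size = 1, prob = 0.5), nrow = 5, ncol = 2)
## Recycling multiplies column 1 of Xtoy into each column of Ztoy: W_1 = X_1 o Z.
W1 <- Xtoy[, 1] * Ztoy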
The objective function used for the pliable lasso is
\[ J(\beta_0, \theta_0, \beta, \Theta) = \frac{1}{2N}\sum_{i=1}^N (y_i-\hat{y}_i)^2 + (1-\alpha)\lambda\sum_{j=1}^p\biggl(||(\beta_j,\theta_j)||_2 + ||\theta_j||_2\biggr) + \alpha\lambda\sum_{j,k}|\theta_{j,k}|. \]
In the above, \(\Theta\) is a \(p\times K\) matrix of parameters with \(j\)-th row \(\theta_j\) and individual entries \(\theta_{j,k}\), and \(\lambda\) is a tuning parameter. As \(\alpha \rightarrow 1\) (but \(\alpha < 1\)), the solution approaches the lasso solution. The default value used is \(\alpha = 0.5.\)
An R package for the pliable lasso is forthcoming from the authors. Nevertheless, the pliable lasso is an excellent example for highlighting the prototyping capabilities of CVXR in research. Along the way, we also illustrate some additional atoms that this example requires.
The pliable lasso in CVXR
We will use a simulated example from section 3 of Tibshirani and Friedman (2017) with \(n=100\), \(p=50\) and \(K=4\). The response is generated as
\[ \begin{eqnarray*} y &=& \mu(x) + 0.5\cdot \epsilon;\ \ \epsilon \sim N(0, 1)\\ \mu(x) &=& x_1\beta_1 + x_2\beta_2 + x_3(\beta_3 e + 2z_1) + x_4\beta_4(e - 2z_2);\ \ \beta = (2, -2, 2, 2, 0, 0, \ldots) \end{eqnarray*} \]
where \(e=(1,1,\ldots,1)^T.\)
## Simulation data.
set.seed(123)
N <- 100
K <- 4
p <- 50
X <- matrix(rnorm(n = N * p, mean = 0, sd = 1), nrow = N, ncol = p)
Z <- matrix(rbinom(n = N * K, size = 1, prob = 0.5), nrow = N, ncol = K)
## Response model.
beta <- rep(x = 0, times = p)
beta[1:4] <- c(2, -2, 2, 2)
coeffs <- cbind(beta[1], beta[2], beta[3] + 2 * Z[, 1], beta[4] * (1 - 2 * Z[, 2]))
mu <- diag(X[, 1:4] %*% t(coeffs))
y <- mu + 0.5 * rnorm(N, mean = 0, sd = 1)
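As a sanity check, \(\mu\) can be computed equivalently with rowSums, avoiding the \(N\times N\) intermediate matrix from which diag() extracts the diagonal.

## Equivalent computation of mu without the N x N intermediate.
stopifnot(all.equal(mu, rowSums(X[, 1:4] * coeffs)))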
It is worthwhile to write a function that fits the model for us so that we can customize a few things, such as an intercept term and verbosity. The function has the following structure, with comments as placeholders for code we shall construct later.
plasso_fit <- function(y, X, Z, lambda, alpha = 0.5, intercept = TRUE,
                       ZERO_THRESHOLD = 1e-6, verbose = FALSE) {
    N <- length(y)
    p <- ncol(X)
    K <- ncol(Z)
    beta0 <- 0
    if (intercept) {
        beta0 <- Variable(1) * matrix(1, nrow = N, ncol = 1)
    }
    ## Define_Parameters
    ## Build_Penalty_Terms
    ## Compute_Fitted_Value
    ## Build_Objective
    ## Define_and_Solve_Problem
    ## Return_Values
}
## Fit pliable lasso using CVXR.
## pliable <- plasso_fit(y, X, Z, alpha = 0.5, lambda = lambda)
Defining the parameters
The model parameters are easy: we just have \(\beta\), \(\theta_0\), and \(\Theta\), each declared as a CVXR variable.
beta <- Variable(p)
theta0 <- Variable(K)
theta <- Variable(p, K) ; theta_transpose <- t(theta)
Note that we also define the transpose of \(\Theta\) for use later.
The penalty terms
There are three of them. The first term in the parentheses, \(\sum_{j=1}^p||(\beta_j,\theta_j)||_2\), involves components of \(\beta\) and rows of \(\Theta\). CVXR provides two functions to express this norm:

- hstack, to bind \(\beta\) and the matrix \(\Theta\) column-wise, the equivalent of cbind in R;
- cvxr_norm, which accepts a matrix variable and an axis denoting the axis along which the norm is to be taken. The penalty requires the norm along rows, so axis = 1, per the usual R convention.
The second term in the parentheses, \(\sum_{j}||\theta_j||_2\), is also a norm along rows, since the \(\theta_j\) are rows of \(\Theta\). The last term is simply the elementwise 1-norm of \(\Theta\).
penalty_term1 <- sum(cvxr_norm(hstack(beta, theta), 2, axis = 1))
penalty_term2 <- sum(cvxr_norm(theta, 2, axis = 1))
penalty_term3 <- sum(cvxr_norm(theta, 1))
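As a quick shape check (a sketch; we assume dim() reports the dimensions of a CVXR expression just as it does for an R matrix), the stacked expression should be \(p \times (K+1)\).

dim(hstack(beta, theta))  ## expected: p rows and K + 1 columns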
The fitted value
Equation 1 above for \(\hat{y}\) contains the sum \(\sum_{j=1}^p(X_j\beta_j + W_j\theta_j)\), which requires multiplying \(Z\) elementwise by each column of \(X\). That is a natural candidate for a map-reduce combination: map the column-multiplication function over the indices \(j = 1, \ldots, p\) and reduce using + to obtain the XZ_term below.
## Each element of xz_theta is the expression W_j %*% theta_j, an N-vector.
xz_theta <- lapply(seq_len(p),
                   function(j) (matrix(X[, j], nrow = N, ncol = K) * Z) %*% theta_transpose[, j])
XZ_term <- Reduce(f = '+', x = xz_theta)
y_hat <- beta0 + X %*% beta + Z %*% theta0 + XZ_term
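For comparison, here is an equivalent loop-free sketch: stack the \(W_j\) into a single \(N \times pK\) matrix and multiply it by the rows of \(\Theta\) stacked into one long vector via CVXR's vec atom. Both formulations yield the same expression.

## Sketch: W = [W_1, ..., W_p] is N x (p * K); vec(t(theta)) stacks the rows
## of theta, so W %*% vec(t(theta)) equals XZ_term above.
W <- do.call(cbind, lapply(seq_len(p), function(j) X[, j] * Z))
XZ_term_alt <- W %*% vec(t(theta))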
The objective
The objective is now straightforward.
objective <- sum_squares(y - y_hat) / (2 * N) +
(1 - alpha) * lambda * (penalty_term1 + penalty_term2) +
alpha * lambda * penalty_term3
The problem and its solution
prob <- Problem(Minimize(objective))
result <- solve(prob, verbose = verbose)
beta_hat <- result$getValue(beta)
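Before extracting estimates, it is prudent to confirm that the solver succeeded; a minimal check using the status field returned by CVXR:

## Proceed only if the solver reports a (possibly inaccurate) optimum.
stopifnot(result$status %in% c("optimal", "optimal_inaccurate"))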
The return values
We create a list of the values of interest to us. Since sparsity is desired, we set estimates whose absolute values fall below ZERO_THRESHOLD to zero.
theta0_hat <- result$getValue(theta0)
theta_hat <- result$getValue(theta)
## Zero out small values before returning.
beta_hat[abs(beta_hat) < ZERO_THRESHOLD] <- 0.0
theta0_hat[abs(theta0_hat) < ZERO_THRESHOLD] <- 0.0
theta_hat[abs(theta_hat) < ZERO_THRESHOLD] <- 0.0
list(beta0_hat = if (intercept) result$getValue(beta0)[1] else 0.0,
beta_hat = beta_hat,
theta0_hat = theta0_hat,
theta_hat = theta_hat,
criterion = result$value)
The full function
We now put it all together.
plasso_fit <- function(y, X, Z, lambda, alpha = 0.5, intercept = TRUE,
                       ZERO_THRESHOLD = 1e-6, verbose = FALSE) {
    N <- length(y)
    p <- ncol(X)
    K <- ncol(Z)
    beta0 <- 0
    if (intercept) {
        beta0 <- Variable(1) * matrix(1, nrow = N, ncol = 1)
    }
    beta <- Variable(p)
    theta0 <- Variable(K)
    theta <- Variable(p, K); theta_transpose <- t(theta)
    penalty_term1 <- sum(cvxr_norm(hstack(beta, theta), 2, axis = 1))
    penalty_term2 <- sum(cvxr_norm(theta, 2, axis = 1))
    penalty_term3 <- sum(cvxr_norm(theta, 1))
    ## Each element of xz_theta is the expression W_j %*% theta_j.
    xz_theta <- lapply(seq_len(p),
                       function(j) (matrix(X[, j], nrow = N, ncol = K) * Z) %*% theta_transpose[, j])
    XZ_term <- Reduce(f = '+', x = xz_theta)
    y_hat <- beta0 + X %*% beta + Z %*% theta0 + XZ_term
    objective <- sum_squares(y - y_hat) / (2 * N) +
        (1 - alpha) * lambda * (penalty_term1 + penalty_term2) +
        alpha * lambda * penalty_term3
    prob <- Problem(Minimize(objective))
    result <- solve(prob, verbose = verbose)
    beta_hat <- result$getValue(beta)
    theta0_hat <- result$getValue(theta0)
    theta_hat <- result$getValue(theta)
    ## Zero out small values before returning.
    beta_hat[abs(beta_hat) < ZERO_THRESHOLD] <- 0.0
    theta0_hat[abs(theta0_hat) < ZERO_THRESHOLD] <- 0.0
    theta_hat[abs(theta_hat) < ZERO_THRESHOLD] <- 0.0
    list(beta0_hat = if (intercept) result$getValue(beta0)[1] else 0.0,
         beta_hat = beta_hat,
         theta0_hat = theta0_hat,
         theta_hat = theta_hat,
         criterion = result$value)
}
The Results
Using \(\lambda = 0.6\), we fit the pliable lasso without an intercept.
result <- plasso_fit(y, X, Z, lambda = 0.6, alpha = 0.5, intercept = FALSE)
We can print the various estimates.
cat(sprintf("Objective value: %f\n", result$criterion))
## Objective value: 4.153289
We only print the nonzero \(\beta\) values.
index <- which(result$beta_hat != 0)
est.table <- data.frame(matrix(result$beta_hat[index], nrow = 1))
names(est.table) <- paste0("$\\beta_{", index, "}$")
knitr::kable(est.table, format = "html", digits = 3) %>%
kable_styling("striped")
| \(\beta_{1}\) | \(\beta_{2}\) | \(\beta_{3}\) | \(\beta_{4}\) | \(\beta_{20}\) |
|---|---|---|---|---|
| 1.716 | -1.385 | 2.217 | 0.285 | -0.038 |
For this value of \(\lambda\), the nonzero coefficients \((\beta_1, \beta_2, \beta_3, \beta_4)\) are picked up along with a few others \((\beta_{20}, \beta_{34}, \beta_{39}).\)
The estimated values for \(\theta_0\):
est.table <- data.frame(matrix(result$theta0_hat, nrow = 1))
names(est.table) <- paste0("$\\theta_{0,", 1:K, "}$")
knitr::kable(est.table, format = "html", digits = 3) %>%
kable_styling("striped")
| \(\theta_{0,1}\) | \(\theta_{0,2}\) | \(\theta_{0,3}\) | \(\theta_{0,4}\) |
|---|---|---|---|
| -0.136 | 0.25 | -0.516 | 0.068 |
And just the first five rows of \(\Theta\), which happen to contain all the nonzero values for this result.
est.table <- data.frame(result$theta_hat[1:5, ])
names(est.table) <- paste0("$\\theta_{,", 1:K, "}$")
knitr::kable(est.table, format = "html", digits = 3) %>%
kable_styling("striped")
| \(\theta_{,1}\) | \(\theta_{,2}\) | \(\theta_{,3}\) | \(\theta_{,4}\) |
|---|---|---|---|
| 0.000 | 0.000 | 0.000 | 0.000 |
| 0.000 | 0.000 | 0.000 | 0.000 |
| 0.528 | 0.000 | 0.085 | 0.308 |
| 0.063 | -0.592 | -0.019 | -0.185 |
| 0.000 | 0.000 | 0.000 | 0.000 |
Final comments
Typically, one would fit the model for a range of \(\lambda\) values, choose one by cross-validation, and assess the predictions on a held-out test set. Here even a single fit takes a while, but techniques discussed in other articles on this site can be used to speed up the computations. A minimal sketch of such a \(\lambda\) path is shown below.
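The sketch reuses plasso_fit from above; no cross-validation machinery is shown, and the grid of \(\lambda\) values is illustrative.

## Fit along a decreasing grid of lambda values.
lambdas <- exp(seq(log(2), log(0.05), length.out = 10))
fits <- lapply(lambdas,
               function(lam) plasso_fit(y, X, Z, lambda = lam, alpha = 0.5, intercept = FALSE))
## Number of nonzero main effects retained at each lambda.
sapply(fits, function(fit) sum(fit$beta_hat != 0))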
A logistic regression using a pliable lasso model can be prototyped similarly.
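For instance, for a hypothetical binary response coded as \(\pm 1\), only the loss term inside plasso_fit changes; a sketch of the swap using CVXR's logistic atom (the penalty terms are unchanged):

## Logistic loss: logistic(u) = log(1 + exp(u)), so this is the negative
## log-likelihood for y in {-1, 1} with linear predictor y_hat.
loss <- sum(logistic(-y * y_hat)) / N
objective <- loss +
    (1 - alpha) * lambda * (penalty_term1 + penalty_term2) +
    alpha * lambda * penalty_term3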
Session Info
sessionInfo()
## R version 4.4.2 (2024-10-31)
## Platform: x86_64-apple-darwin20
## Running under: macOS Sequoia 15.1
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.0
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## time zone: America/Los_Angeles
## tzcode source: internal
##
## attached base packages:
## [1] stats graphics grDevices datasets utils methods base
##
## other attached packages:
## [1] kableExtra_1.4.0 CVXR_1.0-15 testthat_3.2.1.1 here_1.0.1
##
## loaded via a namespace (and not attached):
## [1] gmp_0.7-5 clarabel_0.9.0.1 sass_0.4.9 xml2_1.3.6
## [5] slam_0.1-54 blogdown_1.19 stringi_1.8.4 lattice_0.22-6
## [9] digest_0.6.37 magrittr_2.0.3 evaluate_1.0.1 grid_4.4.2
## [13] bookdown_0.41 fastmap_1.2.0 rprojroot_2.0.4 jsonlite_1.8.9
## [17] Matrix_1.7-1 ECOSolveR_0.5.5 brio_1.1.5 Rmosek_10.2.0
## [21] viridisLite_0.4.2 scales_1.3.0 codetools_0.2-20 jquerylib_0.1.4
## [25] cli_3.6.3 Rmpfr_0.9-5 rlang_1.1.4 Rglpk_0.6-5.1
## [29] bit64_4.5.2 munsell_0.5.1 cachem_1.1.0 yaml_2.3.10
## [33] tools_4.4.2 Rcplex_0.3-6 rcbc_0.1.0.9001 colorspace_2.1-1
## [37] gurobi_11.0-0 assertthat_0.2.1 vctrs_0.6.5 R6_2.5.1
## [41] lifecycle_1.0.4 stringr_1.5.1 bit_4.5.0 cccp_0.3-1
## [45] bslib_0.8.0 glue_1.8.0 Rcpp_1.0.13-1 systemfonts_1.1.0
## [49] highr_0.11 xfun_0.49 rstudioapi_0.17.1 knitr_1.48
## [53] htmltools_0.5.8.1 rmarkdown_2.29 svglite_2.1.3 compiler_4.4.2