1from .base import enable_logging, requires_fit, model
2from ..tidystats.broom import tidy
3from ..tidystats.lmerTest import ranef, lmer as lmer_, bootMer
4from ..tidystats.multimodel import coef
5from ..tidystats.easystats import get_param_names, is_converged
6from ..tidystats.tables import summary_lmm_table
7from ..expressions import logit2odds
8from polars import DataFrame, col
9import polars.selectors as cs
10from rpy2.robjects import NULL, NA_Real
11
12
[docs]
class lmer(model):
    """Linear mixed effects model estimated via ML/REML.

    This class implements linear mixed effects models using Maximum Likelihood or
    Restricted Maximum Likelihood estimation (``lmerTest::lmer()`` in R). It extends
    the base ``model`` class to handle random effects and nested data structures.

    Args:
        formula (str): R-style formula specifying the model, including random effects
        data (DataFrame): Input data for the model
        **kwargs: Forwarded to the base ``model`` constructor. ``REML`` (default
            ``True``) is read back from the stored init kwargs.
    """

    def __init__(self, formula, data, **kwargs):
        """Initialize the linear mixed effects model.

        Args:
            formula (str): R-style formula specifying the model, including random effects
            data (DataFrame): Input data for the model
            **kwargs: Forwarded to the base ``model`` constructor. ``REML`` (default
                ``True``) is read back from the stored init kwargs.
        """
        super().__init__(formula, data, **kwargs)
        self._r_func = lmer_
        self._summary_func = summary_lmm_table
        self.fixef = None
        self.ranef = None
        self.ranef_var = None
        self.convergence_status = None
        self._REML = self._init_kwargs.get("REML", True)

    def _handle_rfx(self, **kwargs):
        """Populate the random-effect attributes of a fitted model.

        Sets ``.ranef_var`` using ``broom.mixed::tidy()`` (random-effect parameters
        with CIs), ``.ranef`` using ``lme4::ranef()`` and ``.fixef`` using
        ``lme4::coef()`` to get random effects and BLUPs. When ``coef()`` returns a
        list (multiple grouping factors), ``.fixef`` and ``.ranef`` become dicts keyed
        by grouping-factor name. Because ``broom.mixed::tidy()`` does not transform
        these tables, ``logit2odds`` is applied manually to every column except
        ``level`` when ``exponentiate=True``.

        Args:
            **kwargs: Forwarded to ``tidy()``; ``exponentiate`` is also read here.
        """
        self.ranef_var = tidy(
            self.r_model, effects="ran_pars", conf_int=True, **kwargs
        ).drop("effect", strict=False)
        self.ranef = ranef(self.r_model)
        self.fixef = coef(self.r_model)

        # Ensure multiple rfx are returned as a dict keyed by grouping factor
        if isinstance(self.fixef, list):
            _, random_names = get_param_names(self.r_model)
            group_names = list(random_names.keys())
            self.fixef = dict(zip(group_names, self.fixef))
            self.ranef = dict(zip(group_names, self.ranef))

        if not kwargs.get("exponentiate", False):
            return

        def _to_odds(frame):
            # Keep the grouping-level label untouched; transform every other column
            return frame.with_columns(col("level"), logit2odds(cs.exclude("level")))

        if isinstance(self.fixef, dict):
            self.fixef = {name: _to_odds(frame) for name, frame in self.fixef.items()}
        else:
            self.fixef = _to_odds(self.fixef)

        if isinstance(self.ranef, dict):
            self.ranef = {name: _to_odds(frame) for name, frame in self.ranef.items()}
        else:
            self.ranef = _to_odds(self.ranef)

    def _bootstrap(
        self,
        nboot,
        save_boots,
        conf_method="perc",
        parallel="multicore",
        ncpus=4,
        conf_level=0.95,
        **kwargs,
    ):
        """Compute bootstrapped CIs for fixed and random effects via ``bootMer``.

        Unlike with `lm()`, we don't use `easystats` functions because they don't
        return the full bootstrap distribution for rfx, only ffx. We use `tidy` to
        summarize the bootstrap distributions and can therefore use all the
        `conf_method` values that it supports (e.g. `"perc"`, `"bca"`, `"norm"`,
        `"basic"`). See `lme4's confint.merMod
        <https://www.rdocumentation.org/packages/lme4/versions/1.1-36/topics/confint.merMod>`_.

        Args:
            nboot (int): Number of bootstrap samples
            save_boots (bool): Whether to store the bootstrap samples in ``.result_boots``
            conf_method (str, optional): Type of bootstrap confidence intervals. Defaults to "perc"
            parallel (str, optional): Parallelization method. Defaults to "multicore"
            ncpus (int, optional): Number of CPUs to use. Defaults to 4
            conf_level (float, optional): Confidence level for intervals. Defaults to 0.95
            **kwargs: Additional arguments passed to ``bootMer``

        Returns:
            None. Updates ``.cis``, the ``conf_low``/``conf_high`` columns of
            ``.result_fit`` and ``.ranef_var``, and (if ``save_boots``) ``.result_boots``.

        Raises:
            ValueError: If the bootstrap output lacks a CI for any fixed-effect term.
        """
        from polars import Series

        cis, boots = bootMer(
            self.r_model,
            nsim=nboot,
            conf_level=conf_level,
            conf_method=conf_method,
            parallel=parallel,
            ncpus=ncpus,
            save_boots=save_boots,
            **kwargs,
        )
        self.cis = cis

        # Fixed CIs: match by term name so they line up with the rows of
        # ``result_fit`` regardless of the order ``bootMer`` lists terms in.
        fixed_names = self.params["term"].to_list()
        fixed_cis = cis.filter(col("term").is_in(fixed_names))
        ci_by_term = {
            row["term"]: (row["conf_low"], row["conf_high"])
            for row in fixed_cis.iter_rows(named=True)
        }
        missing = [term for term in fixed_names if term not in ci_by_term]
        if missing:
            raise ValueError(
                f"Bootstrap results are missing fixed-effect terms: {missing}"
            )
        self.result_fit = self.result_fit.with_columns(
            Series(
                "conf_low",
                [ci_by_term[term][0] for term in fixed_names],
                dtype=fixed_cis.schema["conf_low"],
            ),
            Series(
                "conf_high",
                [ci_by_term[term][1] for term in fixed_names],
                dtype=fixed_cis.schema["conf_high"],
            ),
        )

        # Drop fixed-effect rows and split the "<term>___<group>" labels into
        # separate term and group columns
        ranef_cis = (
            cis.filter(~col("term").is_in(fixed_names))
            .with_columns(
                col("term")
                .str.split_exact("___", 2)
                .explode()
                .struct.rename_fields(["term", "group"])
                .struct.unnest()
            )
            .select("group", "term", "conf_low", "conf_high")
        )
        self.ranef_var = self.ranef_var.drop("conf_low", "conf_high").join(
            ranef_cis, on=["term", "group"]
        )

        if save_boots:
            self.result_boots = boots

    @enable_logging
    def fit(
        self,
        summary=False,
        conf_method="satterthwaite",
        nboot=1000,
        save_boots=True,
        parallel="multicore",
        ncpus=4,
        conf_type="perc",
        bootMer_kwargs=None,
        **kwargs,
    ):
        """Fit a linear mixed effects model using ``lmer()`` in R with Satterthwaite degrees of freedom and p-values calculated using ``lmerTest``.

        Args:
            summary (bool, optional): Whether to return the model summary. Defaults to False
            conf_method (str, optional): Method for confidence interval calculation. Defaults to ``"satterthwaite"``. Alternatively, ``"boot"`` for bootstrap CIs.
            nboot (int, optional): Number of bootstrap samples. Defaults to 1000
            save_boots (bool, optional): Whether to keep bootstrap samples in ``.result_boots``. Defaults to True
            parallel (str, optional): Parallelization for bootstrapping. Defaults to "multicore"
            ncpus (int, optional): Number of cores to use for parallelization. Defaults to 4
            conf_type (str, optional): Type of bootstrap confidence interval to calculate. Defaults to "perc"
            bootMer_kwargs (dict, optional): Extra arguments passed to ``bootMer`` when ``conf_method="boot"``. Defaults to None (no extra arguments)
            **kwargs: Passed to the base class ``fit()`` and to the random-effects ``tidy()`` call (e.g. ``exponentiate``)

        Returns:
            GT, optional: Model summary if ``summary=True``
        """
        if bootMer_kwargs is None:
            bootMer_kwargs = {}

        use_boot = conf_method == "boot"
        if use_boot:
            # Bootstrap CIs are computed separately below; get analytic fixed-effect
            # CIs first (Satterthwaite when no ``family`` is set, Wald otherwise).
            base_conf_method = "satterthwaite" if self.family is None else "wald"
        else:
            base_conf_method = conf_method

        # Use super to get fixed effects via easystats::model_parameters()
        super().fit(
            conf_method=base_conf_method,
            effects="fixed",
            ci_random=False,
            parallel=parallel,
            ncpus=ncpus,
            conf_type=conf_type,
            **kwargs,
        )

        # Store the conf_method in the fit_kwargs since we overwrite it in the super call
        self._fit_kwargs["conf_method"] = conf_method

        # Get random effects
        self._handle_rfx(**kwargs)

        if use_boot:
            self._bootstrap(
                nboot=nboot,
                save_boots=save_boots,
                conf_method=conf_type,
                parallel=parallel,
                ncpus=ncpus,
                **bootMer_kwargs,
            )

        # Handle convergence & singularity warnings
        did_converge, message = is_converged(self.r_model)
        self.convergence_status = message
        if not did_converge:
            self.r_console.append(message)

        if summary:
            return self.summary()

    @enable_logging
    def anova(
        self,
        summary=False,
        auto_ss_3=True,
        jointtest_kwargs=None,
        anova_kwargs=None,
    ):
        """Calculate a Type-III ANOVA table for the model using ``joint_tests()`` in R.

        Args:
            summary (bool): whether to return the ANOVA summary. Defaults to False
            auto_ss_3 (bool): whether to automatically use balanced contrasts when calculating the result via `joint_tests()`. When False, will use the contrasts specified with `set_contrasts()` which defaults to `"contr.treatment"` and R's `anova()` function; Default is True.
            jointtest_kwargs (dict): additional arguments to pass to `joint_tests()`. Defaults to ``{"mode": "satterthwaite", "lmer_df": "satterthwaite"}`` (Satterthwaite degrees of freedom)
            anova_kwargs (dict): additional arguments to pass to `anova()`. Defaults to no extra arguments
        """
        # Fresh dicts per call so defaults are never shared/mutated across calls
        if jointtest_kwargs is None:
            jointtest_kwargs = {"mode": "satterthwaite", "lmer_df": "satterthwaite"}
        if anova_kwargs is None:
            anova_kwargs = {}

        super().anova(
            auto_ss_3=auto_ss_3,
            jointtest_kwargs=jointtest_kwargs,
            anova_kwargs=anova_kwargs,
        )
        if summary:
            return self.summary_anova()

    @enable_logging
    @requires_fit
    def emmeans(self, marginal_var, by=None, p_adjust="sidak", **kwargs):
        """Compute marginal means and/or contrasts between factor levels.

        ``marginal_var`` is the predictor whose levels will have means or contrasts
        computed. ``by`` is an optional predictor to marginalize over. If ``contrasts``
        is not specified, only marginal means are returned. Satterthwaite degrees of
        freedom are always used.

        Args:
            marginal_var (str): name of predictor to compute means or contrasts for
            by (str/list): additional predictors to marginalize over
            p_adjust (str): multiple comparisons adjustment method. One of: none, tukey, bonf, sidak (default), fdr, holm, dunnet, mvt (monte-carlo multi-variate T, aka exact tukey/dunnet).
            **kwargs: passed to the base class ``emmeans()``, e.g.
                ``contrasts`` (str | 'pairwise' | 'poly' | dict | None): how to specify comparison within `marginal_var`. Defaults to None.

        Returns:
            DataFrame: Table of marginal means or contrasts
        """
        return super().emmeans(
            marginal_var,
            by,
            mode="satterthwaite",
            lmer_df="satterthwaite",
            lmerTest_limit=999999,
            p_adjust=p_adjust,
            **kwargs,
        )

    def predict(self, data: DataFrame, use_rfx=True, **kwargs):
        """Make predictions using new data.

        Args:
            data (DataFrame): Input data for predictions
            use_rfx (bool, optional): Whether to include random effects in predictions. Defaults to True. Equivalent to ``re.form = NULL`` in R if True, ``re.form = NA`` if False
            **kwargs: Additional arguments passed to predict function

        Returns:
            ndarray: Predicted values
        """
        re_form = NULL if use_rfx else NA_Real
        return super().predict(data, re_form=re_form, **kwargs)

    @requires_fit
    def simulate(self, nsim: int = 1, use_rfx=True, **kwargs):
        """Simulate values from the fitted model.

        Args:
            nsim (int, optional): Number of simulations to run. Defaults to 1
            use_rfx (bool, optional): Whether to include random effects in simulations. Defaults to True.
                Equivalent to ``re.form = NULL`` in R if True, ``re.form = NA`` if False
            **kwargs: Additional arguments passed to simulate function

        Returns:
            DataFrame: Simulated values with the same number of rows as the original data
            and columns equal to nsim
        """
        re_form = NULL if use_rfx else NA_Real
        return super().simulate(nsim, re_form=re_form, **kwargs)
292 return super().simulate(nsim, re_form=re_form, **kwargs)