Source code for pymer4.io
# Public API of this module: model save/load helpers and the bundled-dataset loader.
__all__ = ["save_model", "load_model", "load_dataset"]
2
3import polars as pl
4import os
5from rpy2.robjects.packages import importr
6from joblib import dump, load
7from importlib import resources
8
# rpy2 handle to R's `base` package; its saveRDS/readRDS functions are used by
# save_model/load_model to persist the model's R object (`model.r_model`).
base = importr("base")
10
11
[docs]
def save_model(model, filepath, **kwargs):
    """
    Function for saving pymer4 models. The Python object is saved with
    joblib.dump, so filepath must end with .joblib. The model's R object
    (``model.r_model``) is additionally saved with R's saveRDS to a sibling
    file with the same name and an .rds extension
    (e.g. ``fit.joblib`` -> ``fit.rds``).

    Args:
        model (pymer4.models): an instance of a pymer4 model
        filepath (str or path-like): full filepath string ending .joblib
        kwargs: optional keyword arguments to joblib.dump

    Raises:
        IOError: if filepath does not end with .joblib, or if either output
            file does not exist after writing
    """

    filepath = str(filepath)
    if not filepath.endswith(".joblib"):
        raise IOError("filepath must end with .joblib")

    # Swap only the trailing extension; str.replace would also rewrite any
    # ".joblib" appearing earlier in the path (e.g. inside a directory name).
    rds_file = filepath.removesuffix(".joblib") + ".rds"

    # Save the python object
    dump(model, filepath, **kwargs)
    # Explicit check instead of `assert`, which is stripped under `python -O`
    if not os.path.exists(filepath):
        raise IOError(f"Failed to write model file: {filepath}")

    # Persist the R model object with R's own serializer
    base.saveRDS(model.r_model, rds_file)
    if not os.path.exists(rds_file):
        raise IOError(f"Failed to write R model file: {rds_file}")
37
38
[docs]
def load_model(filepath):
    """
    Function for loading pymer4 models. A file path ending in .joblib should be
    provided. A companion .rds file with the same name (e.g. ``fit.joblib`` ->
    ``fit.rds``) holding the model's R object must be located in the same
    directory.

    Args:
        filepath (str or path-like): full filepath string ending with .joblib

    Returns:
        The object loaded from the .joblib file, with its ``r_model``
        attribute set to the R object read from the companion .rds file.

    Raises:
        IOError: if filepath does not end with .joblib
        FileNotFoundError: if the companion .rds file (or the .joblib file)
            does not exist

    Warning:
        joblib.load is pickle-based and can execute arbitrary code; only load
        files from trusted sources.
    """

    filepath = str(filepath)
    if not filepath.endswith(".joblib"):
        raise IOError("filepath must end with .joblib")

    # Swap only the trailing extension; str.replace would also rewrite any
    # ".joblib" appearing earlier in the path (e.g. inside a directory name).
    rds_file = filepath.removesuffix(".joblib") + ".rds"

    # Fail fast with a clear message before the (potentially slow) unpickle,
    # rather than surfacing an opaque error from R's readRDS afterwards.
    if not os.path.exists(rds_file):
        raise FileNotFoundError(f"R model file not found: {rds_file}")

    # Load python object (pickle-based: trusted input only)
    model = load(filepath)

    # Restore the R model object
    model.r_model = base.readRDS(rds_file)
    return model
60
61
[docs]
def load_dataset(name):
    """
    Loads a csv file included with the package as a polars DataFrame.

    Args:
        name (str): dataset name without the .csv extension; must be one of
            the entries in ``valid_names`` below

    Returns:
        polars.DataFrame: contents of ``resources/<name>.csv`` shipped inside
        the installed pymer4 package

    Raises:
        ValueError: if ``name`` is not one of the bundled datasets
    """

    valid_names = [
        "gammas",
        "mtcars",
        "sample_data",
        "sleep",
        "sleepmissing",
        "titanic",
        "titanic_train",
        "titanic_test",
        "poker",
        "chickweight",
        "credit",
        "credit-mini",
        "advertising",
        "penguins",
    ]

    if name not in valid_names:
        raise ValueError(f"Dataset name must be one of: {valid_names}")

    # Join segments one at a time rather than embedding "/" in a single
    # segment, which not every Traversable implementation handles.
    resource = resources.files("pymer4").joinpath("resources").joinpath(f"{name}.csv")

    # files() may return a Traversable that is not a real filesystem path
    # (e.g. when the package is imported from a zip archive). as_file yields a
    # concrete path, extracting to a temporary file if needed. read_csv is
    # eager, so the data is fully loaded before any temporary file is removed.
    with resources.as_file(resource) as fpath:
        return pl.read_csv(fpath)