Source code for pymer4.io

# Public API of this module: the names exported by ``from pymer4.io import *``.
__all__ = ["save_model", "load_model", "load_dataset"]
 2
 3import polars as pl
 4import os
 5from rpy2.robjects.packages import importr
 6from joblib import dump, load
 7from importlib import resources
 8
# Handle to R's ``base`` package, imported through rpy2 when this module is
# imported; its ``saveRDS``/``readRDS`` are used by save_model/load_model.
base = importr("base")
10
11
def save_model(model, filepath, **kwargs):
    """
    Save a pymer4 model to disk.

    The Python object is written with ``joblib.dump`` to ``filepath``, which must
    end in ``.joblib``. The model's R object (``model.r_model``) is written
    separately with R's ``saveRDS`` to a sibling file with the same name but a
    ``.rds`` extension in place of ``.joblib``.

    Args:
        model (pymer4.models): an instance of a pymer4 model; must expose an
            ``r_model`` attribute holding the R model object
        filepath (str or os.PathLike): full filepath string ending with .joblib
        kwargs: optional keyword arguments forwarded to ``joblib.dump``

    Raises:
        IOError: if ``filepath`` does not end with ``.joblib``, or if either
            output file does not exist after writing
    """
    filepath = str(filepath)
    if not filepath.endswith(".joblib"):
        raise IOError("filepath must end with .joblib")

    # Swap only the trailing extension; str.replace would also rewrite any
    # ".joblib" appearing earlier in the path (e.g. in a directory name).
    rds_file = filepath[: -len(".joblib")] + ".rds"

    # Save the python object
    dump(model, filepath, **kwargs)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(filepath):
        raise IOError(f"Failed to write model file: {filepath}")

    # Persist the R-side model object with R's own serializer; load_model
    # restores it from this file with readRDS.
    base.saveRDS(model.r_model, rds_file)
    if not os.path.exists(rds_file):
        raise IOError(f"Failed to write R model file: {rds_file}")
def load_model(filepath):
    """
    Load a pymer4 model previously written by ``save_model``.

    Reads the Python object from ``filepath`` (a ``.joblib`` file), then restores
    its R object from the sibling ``.rds`` file (same name, ``.rds`` extension in
    place of ``.joblib``) and assigns it to ``model.r_model``.

    Args:
        filepath (str or os.PathLike): full filepath string ending with .joblib

    Returns:
        The loaded pymer4 model, with ``r_model`` restored from the ``.rds`` file.

    Raises:
        IOError: if ``filepath`` does not end with ``.joblib``
        FileNotFoundError: if the companion ``.rds`` file does not exist

    Warning:
        ``joblib.load`` unpickles arbitrary Python objects and can execute code
        embedded in the file; only load files from trusted sources.
    """
    filepath = str(filepath)
    if not filepath.endswith(".joblib"):
        raise IOError("filepath must end with .joblib")

    # Swap only the trailing extension; str.replace would also rewrite any
    # ".joblib" appearing earlier in the path (e.g. in a directory name).
    rds_file = filepath[: -len(".joblib")] + ".rds"

    # Fail early with a clear message instead of letting R's readRDS fail
    # after the Python object has already been unpickled.
    if not os.path.exists(rds_file):
        raise FileNotFoundError(f"Missing R model file: {rds_file}")

    # Load python object (pickle-based: trusted input only, see Warning above)
    model = load(filepath)

    # Reattach the R model object saved alongside it
    model.r_model = base.readRDS(rds_file)
    return model
def load_dataset(name):
    """
    Load a CSV dataset bundled with pymer4 as a polars DataFrame.

    Args:
        name (str): dataset name without the ``.csv`` extension; must be one of
            the names in ``valid_names`` below

    Returns:
        polars.DataFrame: the parsed contents of ``resources/<name>.csv`` inside
        the installed ``pymer4`` package

    Raises:
        ValueError: if ``name`` is not one of the bundled datasets
    """

    valid_names = [
        "gammas",
        "mtcars",
        "sample_data",
        "sleep",
        "sleepmissing",
        "titanic",
        "titanic_train",
        "titanic_test",
        "poker",
        "chickweight",
        "credit",
        "credit-mini",
        "advertising",
        "penguins",
    ]

    if name not in valid_names:
        raise ValueError(f"Dataset name must be one of: {valid_names}")

    # One path segment per joinpath call: Traversable implementations are not
    # required to interpret "/" inside a single segment.
    resource = resources.files("pymer4").joinpath("resources").joinpath(f"{name}.csv")

    # as_file yields a real filesystem path even when the package resources do
    # not live on disk (e.g. a zipped install), so polars always gets a path.
    with resources.as_file(resource) as fpath:
        return pl.read_csv(fpath)