该模块定义了ConfigManager类,用于管理和操作YAML格式的配置文件。
主要功能包括:创建配置文件、读取配置、获取和更新配置参数。

import os
import yaml # pyyaml

# ConfigManager.py
# 该模块定义了ConfigManager类,用于管理和操作YAML格式的配置文件。
# 主要功能包括:创建配置文件、读取配置、获取和更新配置参数。
# 这对于管理应用程序设置和参数非常有用,特别是在实验和开发环境中。

class ConfigManager:
    """
    ConfigManager 类用于管理YAML配置文件。
    它允许加载、获取、更新配置文件中的参数。
    """

    def __init__(self, config_path = 'config.yml'):
        """
        初始化ConfigManager实例。

        :param config_path: 配置文件的路径。
        """
        self.config_path = config_path
        # 检查文件是否存在,如果不存在则创建一个空文件
        if not os.path.exists(self.config_path):
            with open(self.config_path, 'w', encoding='utf-8') as file:
                yaml.dump({}, file)
        self.config = self.load_config()

    def load_config(self):
        """
        加载YAML配置文件。
        如果文件不存在或无法解析,则抛出相应的异常。正常情况下会自动创建配置文件

        :return: 返回配置文件的内容。
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as file:
                return yaml.safe_load(file)
        except FileNotFoundError:
            raise FileNotFoundError(f"Config file {self.config_path} not found.")
        except yaml.YAMLError as exc:
            raise RuntimeError(f"Error while parsing YAML file: {exc}")

    def get_param(self, section, key, default=None):
        """
        获取配置文件中的参数值。

        :param section: 配置文件中的部分(如'experiment')。
        :param key: 部分中的键。
        :param default: 如果键不存在,则返回的默认值。
        :return: 返回键对应的值,如果不存在则返回默认值。
        """
        if section in self.config and key in self.config[section]:
            return self.config[section][key]
        else:
            if default is not None:
                self.update_param(section, key, default)
                return default
            else:
                raise KeyError(f"Param '{key}' not found in section '{section}', and no default value provided.")

    def update_param(self, section, key, value):
        """
        更新配置文件中的参数。

        :param section: 配置文件中的部分。
        :param key: 部分中的键。
        :param value: 要更新的值。
        """
        if section not in self.config:
            self.config[section] = {}
        self.config[section][key] = value
        with open(self.config_path, 'w', encoding='utf-8') as file:
            yaml.dump(self.config, file, allow_unicode=True, default_flow_style=False)

if __name__ == '__main__':

    # 使用示例
    config_manager = ConfigManager("config.yml")
    learning_rate = config_manager.get_param('experiment', 'learning_rate', 1.0)
    print(f"Learning Rate: {learning_rate}")

    # 更新参数示例
    config_manager.update_param('experiment', 'learning_rate', 0.02)