该模块定义了ConfigManager类,用于管理和操作YAML格式的配置文件。
主要功能包括:创建配置文件、读取配置、获取和更新配置参数。
import os
import yaml # pyyaml
# ConfigManager.py
# 该模块定义了ConfigManager类,用于管理和操作YAML格式的配置文件。
# 主要功能包括:创建配置文件、读取配置、获取和更新配置参数。
# 这对于管理应用程序设置和参数非常有用,特别是在实验和开发环境中。
class ConfigManager:
"""
ConfigManager 类用于管理YAML配置文件。
它允许加载、获取、更新配置文件中的参数。
"""
def __init__(self, config_path = 'config.yml'):
"""
初始化ConfigManager实例。
:param config_path: 配置文件的路径。
"""
self.config_path = config_path
# 检查文件是否存在,如果不存在则创建一个空文件
if not os.path.exists(self.config_path):
with open(self.config_path, 'w', encoding='utf-8') as file:
yaml.dump({}, file)
self.config = self.load_config()
def load_config(self):
"""
加载YAML配置文件。
如果文件不存在或无法解析,则抛出相应的异常。正常情况下会自动创建配置文件
:return: 返回配置文件的内容。
"""
try:
with open(self.config_path, 'r', encoding='utf-8') as file:
return yaml.safe_load(file)
except FileNotFoundError:
raise FileNotFoundError(f"Config file {self.config_path} not found.")
except yaml.YAMLError as exc:
raise RuntimeError(f"Error while parsing YAML file: {exc}")
def get_param(self, section, key, default=None):
"""
获取配置文件中的参数值。
:param section: 配置文件中的部分(如'experiment')。
:param key: 部分中的键。
:param default: 如果键不存在,则返回的默认值。
:return: 返回键对应的值,如果不存在则返回默认值。
"""
if section in self.config and key in self.config[section]:
return self.config[section][key]
else:
if default is not None:
self.update_param(section, key, default)
return default
else:
raise KeyError(f"Param '{key}' not found in section '{section}', and no default value provided.")
def update_param(self, section, key, value):
"""
更新配置文件中的参数。
:param section: 配置文件中的部分。
:param key: 部分中的键。
:param value: 要更新的值。
"""
if section not in self.config:
self.config[section] = {}
self.config[section][key] = value
with open(self.config_path, 'w', encoding='utf-8') as file:
yaml.dump(self.config, file, allow_unicode=True, default_flow_style=False)
if __name__ == '__main__':
# 使用示例
config_manager = ConfigManager("config.yml")
learning_rate = config_manager.get_param('experiment', 'learning_rate', 1.0)
print(f"Learning Rate: {learning_rate}")
# 更新参数示例
config_manager.update_param('experiment', 'learning_rate', 0.02)