In quantitative investing, an efficient and stable toolchain for data processing and analysis is the foundation of success. This project aims to help learners build a complete pipeline, from data collection, cleaning, and storage through feature extraction and analysis, laying a solid foundation for subsequent strategy development.
Learning objectives:
- Define a standard data source interface (DataSource) that concrete sources implement.
- Build an Extract-Transform-Load (ETL) pipeline with logging and per-source error handling.
- Implement data quality checks covering missing values, duplicates, value ranges, and consistency rules.
- Design a layered store (raw / cleaned / features / results) for time-series data.
import logging
import os

import pandas as pd


class DataSource:
    """Data source base class defining the standard interface."""

    def __init__(self, config):
        """
        Initialize the data source.

        Args:
            config (dict): Configuration, including connection parameters,
                credentials, etc.
        """
        self.config = config
        # The ETL pipeline below keys results by source.id, so an identifier
        # is required; fall back to the class name if none is configured.
        self.id = config.get('id', self.__class__.__name__)
        self.connection = None

    def connect(self):
        """Open the connection to the data source."""
        raise NotImplementedError

    def disconnect(self):
        """Close the connection to the data source."""
        raise NotImplementedError

    def get_data(self, query_params):
        """
        Fetch data.

        Args:
            query_params (dict): Query parameters.

        Returns:
            pd.DataFrame: Query result.
        """
        raise NotImplementedError

    def check_health(self):
        """Check the health of the data source connection."""
        raise NotImplementedError
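To make the interface concrete, here is a minimal hedged sketch of a subclass backed by a local CSV file; the CSVDataSource name, the 'path' config key, and the optional 'columns' query parameter are illustrative assumptions, not part of the framework above.

class CSVDataSource(DataSource):
    """Example data source backed by a local CSV file (hypothetical sketch)."""

    def connect(self):
        # For a local file, "connecting" just means verifying it exists.
        if not os.path.exists(self.config['path']):
            raise FileNotFoundError(self.config['path'])
        self.connection = True
        return True

    def disconnect(self):
        self.connection = None
        return True

    def get_data(self, query_params):
        data = pd.read_csv(self.config['path'])
        # Optional column filter, an assumption of this sketch.
        columns = (query_params or {}).get('columns')
        if columns:
            data = data[[c for c in columns if c in data.columns]]
        return data

    def check_health(self):
        return os.path.exists(self.config['path'])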
class ETLPipeline:
    """ETL (Extract-Transform-Load) pipeline framework."""

    def __init__(self, data_sources, transformers, data_store):
        """
        Initialize the ETL pipeline.

        Args:
            data_sources (list): List of data sources.
            transformers (list): List of data transformers.
            data_store (DataStore): Data store object.
        """
        self.data_sources = data_sources
        self.transformers = transformers
        self.data_store = data_store
        self.logger = self._setup_logger()

    def _setup_logger(self):
        """Configure and return a logger for the pipeline."""
        logger = logging.getLogger(self.__class__.__name__)
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
            logger.addHandler(handler)
        logger.setLevel(logging.INFO)
        return logger

    def extract(self, query_params):
        """
        Extract data from all data sources.

        Args:
            query_params (dict): Query parameters.

        Returns:
            dict: Maps each data source ID to the corresponding DataFrame.
        """
        results = {}
        for source in self.data_sources:
            try:
                source.connect()
                data = source.get_data(query_params)
                results[source.id] = data
                self.logger.info(f"Extracted {len(data)} records from source {source.id}")
            except Exception as e:
                self.logger.error(f"Error extracting from source {source.id}: {e}")
            finally:
                source.disconnect()
        return results

    def transform(self, data_dict):
        """
        Transform the extracted data.

        Args:
            data_dict (dict): Raw extracted data.

        Returns:
            dict: Transformed data.
        """
        transformed = {}
        for source_id, data in data_dict.items():
            curr_data = data.copy()
            for transformer in self.transformers:
                # Fall back to the class name if the transformer has no name.
                name = getattr(transformer, 'name', type(transformer).__name__)
                try:
                    curr_data = transformer.transform(curr_data)
                    self.logger.info(f"Transformer {name} applied to data from source {source_id}")
                except Exception as e:
                    self.logger.error(f"Error transforming data from source {source_id}: {e}")
            transformed[source_id] = curr_data
        return transformed

    def load(self, transformed_data):
        """
        Load data into the storage system.

        Args:
            transformed_data (dict): Transformed data.

        Returns:
            bool: Whether the operation succeeded.
        """
        try:
            self.data_store.connect()
            for source_id, data in transformed_data.items():
                self.data_store.save(data, source_id)
                self.logger.info(f"Loaded {len(data)} records from source {source_id} into the store")
            return True
        except Exception as e:
            self.logger.error(f"Error loading data into the store: {e}")
            return False
        finally:
            self.data_store.disconnect()

    def run(self, query_params):
        """
        Run the full ETL pipeline.

        Args:
            query_params (dict): Query parameters.

        Returns:
            bool: Whether the operation succeeded.
        """
        extracted = self.extract(query_params)
        if not extracted:
            self.logger.warning("No data extracted; aborting ETL run")
            return False
        transformed = self.transform(extracted)
        if not transformed:
            self.logger.warning("Data transformation failed; aborting ETL run")
            return False
        return self.load(transformed)
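With the pieces above in place, a run might be wired up as in this hedged sketch; DropNaTransformer and the prices.csv path are hypothetical, and TimeSeriesDataStore is the layered file store defined later in this section.

# A minimal transformer: any object with a transform(df) method works.
class DropNaTransformer:
    name = 'drop_na'

    def transform(self, df):
        return df.dropna()

source = CSVDataSource({'id': 'prices_csv', 'path': 'prices.csv'})  # hypothetical file
store = TimeSeriesDataStore({'base_path': './data'})
pipeline = ETLPipeline([source], [DropNaTransformer()], store)
success = pipeline.run({'columns': ['date', 'close']})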
def check_missing_values(df, threshold=0.05):
    """
    Check a DataFrame for missing values.

    Args:
        df (pd.DataFrame): DataFrame to check.
        threshold (float): Acceptable missing-value ratio per column.

    Returns:
        dict: Per-column missing-value statistics and an overall verdict.
    """
    missing_counts = df.isnull().sum()
    missing_ratio = missing_counts / len(df)
    columns_above_threshold = missing_ratio[missing_ratio > threshold].index.tolist()
    return {
        'missing_counts': missing_counts.to_dict(),
        'missing_ratio': missing_ratio.to_dict(),
        'columns_above_threshold': columns_above_threshold,
        'passed': len(columns_above_threshold) == 0
    }
def check_duplicates(df, subset=None):
    """
    Check a DataFrame for duplicate records.

    Args:
        df (pd.DataFrame): DataFrame to check.
        subset (list): Column subset used to identify duplicates
            (None means all columns).

    Returns:
        dict: Duplicate statistics and sample indices.
    """
    duplicates = df.duplicated(subset=subset, keep='first')
    duplicate_count = int(duplicates.sum())
    duplicate_indices = duplicates[duplicates].index.tolist()
    return {
        'duplicate_count': duplicate_count,
        'duplicate_ratio': duplicate_count / len(df) if len(df) > 0 else 0,
        'duplicate_indices': duplicate_indices[:10],  # return at most 10 examples
        'passed': duplicate_count == 0
    }
def check_data_range(df, numeric_ranges=None, categorical_values=None):
    """
    Check whether values fall within expected ranges.

    Args:
        df (pd.DataFrame): DataFrame to check.
        numeric_ranges (dict): Valid ranges for numeric columns,
            as {'column_name': (min, max)}.
        categorical_values (dict): Valid value sets for categorical columns,
            as {'column_name': set(valid_values)}.

    Returns:
        dict: Range-check results.
    """
    results = {'numeric_columns': {}, 'categorical_columns': {}, 'passed': True}
    # Check numeric ranges
    if numeric_ranges:
        for column, (min_val, max_val) in numeric_ranges.items():
            if column in df.columns:
                out_of_range = df[(df[column] < min_val) | (df[column] > max_val)]
                results['numeric_columns'][column] = {
                    'out_of_range_count': len(out_of_range),
                    'out_of_range_ratio': len(out_of_range) / len(df) if len(df) > 0 else 0,
                    'min_value': df[column].min(),
                    'max_value': df[column].max(),
                    'passed': len(out_of_range) == 0
                }
                if len(out_of_range) > 0:
                    results['passed'] = False
    # Check categorical values
    if categorical_values:
        for column, valid_values in categorical_values.items():
            if column in df.columns:
                invalid_values = df[~df[column].isin(valid_values)]
                results['categorical_columns'][column] = {
                    'invalid_count': len(invalid_values),
                    'invalid_ratio': len(invalid_values) / len(df) if len(df) > 0 else 0,
                    'unique_values': df[column].unique().tolist(),
                    'passed': len(invalid_values) == 0
                }
                if len(invalid_values) > 0:
                    results['passed'] = False
    return results
def check_data_consistency(df, consistency_rules):
    """
    Apply data consistency rules.

    Args:
        df (pd.DataFrame): DataFrame to check.
        consistency_rules (list): List of rules; each rule is a function that
            takes the DataFrame as its argument and returns a boolean.

    Returns:
        dict: Consistency-check results.
    """
    results = {'rule_results': {}, 'passed': True}
    for i, rule in enumerate(consistency_rules):
        rule_name = getattr(rule, '__name__', f'rule_{i}')
        try:
            rule_passed = rule(df)
            results['rule_results'][rule_name] = rule_passed
            if not rule_passed:
                results['passed'] = False
        except Exception as e:
            results['rule_results'][rule_name] = {
                'error': str(e),
                'passed': False
            }
            results['passed'] = False
    return results
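As a concrete example of a rule, an OHLC price table should never show a high below its low. The function below is a sketch; the high/low column names are assumptions about your schema.

def high_not_below_low(df):
    """Consistency rule (illustrative): every bar's high must be >= its low."""
    return bool((df['high'] >= df['low']).all())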
def run_data_quality_checks(df, checks_config):
    """
    Run a suite of data quality checks.

    Args:
        df (pd.DataFrame): DataFrame to check.
        checks_config (dict): Check configuration.

    Returns:
        dict: Combined results of all quality checks.
    """
    results = {}
    overall_passed = True
    # Missing-value check
    if 'missing_values' in checks_config:
        threshold = checks_config['missing_values'].get('threshold', 0.05)
        missing_check = check_missing_values(df, threshold)
        results['missing_values'] = missing_check
        overall_passed &= missing_check['passed']
    # Duplicate-record check
    if 'duplicates' in checks_config:
        subset = checks_config['duplicates'].get('subset', None)
        duplicate_check = check_duplicates(df, subset)
        results['duplicates'] = duplicate_check
        overall_passed &= duplicate_check['passed']
    # Data-range check
    if 'data_range' in checks_config:
        numeric_ranges = checks_config['data_range'].get('numeric_ranges', None)
        categorical_values = checks_config['data_range'].get('categorical_values', None)
        range_check = check_data_range(df, numeric_ranges, categorical_values)
        results['data_range'] = range_check
        overall_passed &= range_check['passed']
    # Consistency check
    if 'consistency' in checks_config:
        consistency_rules = checks_config['consistency'].get('rules', [])
        consistency_check = check_data_consistency(df, consistency_rules)
        results['consistency'] = consistency_check
        overall_passed &= consistency_check['passed']
    results['overall_passed'] = overall_passed
    return results
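Putting the four checks together, here is a hedged configuration example for a hypothetical daily price table; the column names, ranges, and exchange codes are placeholder assumptions, and high_not_below_low is the rule sketched above.

# Toy price table for illustration only.
df = pd.DataFrame({
    'date': ['2024-01-02', '2024-01-02'],
    'symbol': ['600000', '600001'],
    'exchange': ['SSE', 'SSE'],
    'close': [10.1, 8.5],
    'high': [10.3, 8.7],
    'low': [10.0, 8.4],
})

checks_config = {
    'missing_values': {'threshold': 0.01},
    'duplicates': {'subset': ['date', 'symbol']},
    'data_range': {
        'numeric_ranges': {'close': (0, 1e6)},
        'categorical_values': {'exchange': {'SSE', 'SZSE'}},
    },
    'consistency': {'rules': [high_not_below_low]},
}
report = run_data_quality_checks(df, checks_config)
print(report['overall_passed'])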
class DataStore:
    """Data store base class defining the standard interface."""

    def __init__(self, config):
        """
        Initialize the data store.

        Args:
            config (dict): Configuration, including connection parameters,
                storage paths, etc.
        """
        self.config = config
        self.connection = None

    def connect(self):
        """Open the connection to the storage system."""
        raise NotImplementedError

    def disconnect(self):
        """Close the connection to the storage system."""
        raise NotImplementedError

    def save(self, data, dataset_id):
        """
        Save data.

        Args:
            data (pd.DataFrame): Data to save.
            dataset_id (str): Dataset identifier.

        Returns:
            bool: Whether the operation succeeded.
        """
        raise NotImplementedError

    def load(self, dataset_id, query_params=None):
        """
        Load data.

        Args:
            dataset_id (str): Dataset identifier.
            query_params (dict): Query parameters.

        Returns:
            pd.DataFrame: The loaded data.
        """
        raise NotImplementedError

    def delete(self, dataset_id, query_params=None):
        """
        Delete data.

        Args:
            dataset_id (str): Dataset identifier.
            query_params (dict): Query parameters selecting the records to delete.

        Returns:
            bool: Whether the operation succeeded.
        """
        raise NotImplementedError

    def list_datasets(self):
        """
        List all datasets.

        Returns:
            list: List of dataset identifiers.
        """
        raise NotImplementedError
# Example layered storage structure
class TimeSeriesDataStore(DataStore):
    """
    Layered storage implementation for time-series data.

    Storage layout:
        - raw data layer (raw)
        - cleaned data layer (cleaned)
        - feature data layer (features)
        - analysis results layer (results)
    """

    def __init__(self, config):
        super().__init__(config)
        self.base_path = config.get('base_path', './data')
        self.layers = {
            'raw': os.path.join(self.base_path, 'raw'),
            'cleaned': os.path.join(self.base_path, 'cleaned'),
            'features': os.path.join(self.base_path, 'features'),
            'results': os.path.join(self.base_path, 'results')
        }
        self._ensure_directories()

    def _ensure_directories(self):
        """Make sure the storage directories exist."""
        for path in self.layers.values():
            os.makedirs(path, exist_ok=True)

    def connect(self):
        """For file-system storage, connecting reduces to checking write access."""
        for path in self.layers.values():
            if not os.access(path, os.W_OK):
                raise PermissionError(f"Cannot write to directory: {path}")
        return True

    def disconnect(self):
        """No special action is needed to disconnect from the file system."""
        return True

    def save(self, data, dataset_id, layer='raw'):
        """
        Save data to the given layer.

        Args:
            data (pd.DataFrame): Data to save.
            dataset_id (str): Dataset identifier.
            layer (str): Target layer ('raw', 'cleaned', 'features', 'results').

        Returns:
            bool: Whether the operation succeeded.
        """
        if layer not in self.layers:
            raise ValueError(f"Invalid storage layer: {layer}")
        file_path = os.path.join(self.layers[layer], f"{dataset_id}.parquet")
        try:
            # Requires a parquet engine (pyarrow or fastparquet) to be installed.
            data.to_parquet(file_path, index=True)
            return True
        except Exception as e:
            print(f"Failed to save data: {e}")
            return False

    def load(self, dataset_id, layer='raw', query_params=None):
        """
        Load data from the given layer.

        Args:
            dataset_id (str): Dataset identifier.
            layer (str): Target layer ('raw', 'cleaned', 'features', 'results').
            query_params (dict): Optional filter conditions.

        Returns:
            pd.DataFrame: The loaded data.
        """
        if layer not in self.layers:
            raise ValueError(f"Invalid storage layer: {layer}")
        file_path = os.path.join(self.layers[layer], f"{dataset_id}.parquet")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Dataset file not found: {file_path}")
        # Basic load
        data = pd.read_parquet(file_path)
        # Apply query filters
        if query_params:
            # Date-range filter
            if 'start_date' in query_params and 'date_column' in query_params:
                date_col = query_params['date_column']
                data = data[data[date_col] >= query_params['start_date']]
            if 'end_date' in query_params and 'date_column' in query_params:
                date_col = query_params['date_column']
                data = data[data[date_col] <= query_params['end_date']]
            # Column selection
            if 'columns' in query_params:
                columns = [col for col in query_params['columns'] if col in data.columns]
                data = data[columns]
        return data

    def list_datasets(self, layer='raw'):
        """
        List all datasets in the given layer.

        Args:
            layer (str): Target layer ('raw', 'cleaned', 'features', 'results').

        Returns:
            list: List of dataset identifiers.
        """
        if layer not in self.layers:
            raise ValueError(f"Invalid storage layer: {layer}")
        path = self.layers[layer]
        files = [f for f in os.listdir(path) if f.endswith('.parquet')]
        return [os.path.splitext(f)[0] for f in files]
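Finally, a short hedged sketch of the layered store in use; the sh600000_daily identifier and the toy price frame are illustrative assumptions.

store = TimeSeriesDataStore({'base_path': './data'})
store.connect()

# Save raw data, then promote a cleaned version to the next layer.
raw = pd.DataFrame({'date': pd.date_range('2024-01-02', periods=3),
                    'close': [10.1, None, 10.2]})
store.save(raw, 'sh600000_daily', layer='raw')
store.save(raw.dropna(), 'sh600000_daily', layer='cleaned')

# Load back a filtered slice from the cleaned layer.
subset = store.load('sh600000_daily', layer='cleaned',
                    query_params={'date_column': 'date',
                                  'start_date': '2024-01-03',
                                  'columns': ['date', 'close']})
print(store.list_datasets(layer='cleaned'))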