Files
HKDataManagment/base/StockDataImporter.py

500 lines
19 KiB
Python
Raw Permalink Normal View History

2025-09-03 21:41:11 +08:00
"""
存储上海证券交易股票列表数据
不确定其数据爬取规则防止 IP 被封
暂时使用该方案获取股票列表数据
下载excel,收到导入到数据库
"""
import pandas as pd
import os
import sys
import csv
import chardet # 用于检测文件编码
from pathlib import Path
from datetime import datetime
from .MySQLHelper import MySQLHelper
from .LogHelper import LogHelper
logger = LogHelper(logger_name = 'execelImport').setup()
class StockDataImporter:
"""股票数据导入工具支持CSV"""
def __init__(self, db_config: dict, column_mapping: dict, data_dir: Path):
self.db_config = db_config
self.column_mapping = column_mapping
self.data_dir = data_dir
self.df = None
self.csv_file = None
self.encoding = 'utf-8' # 默认编码
self.delimiter = ',' # 默认分隔符
self.upload_filename = None # 上传文件名
# 更新 检讨标志
def setUploadfile(self, filename: str):
"""设置需要上传的文件名"""
self.upload_filename = filename
logger.info(f"设置上传文件名为: {filename}")
def find_csv_file(self) -> Path:
"""在data文件夹中查找CSV文件"""
# 使用设置的upload_filename或默认文件名
filename = self.upload_filename if self.upload_filename else "GPLIST.csv"
# 查找所有匹配的文件
csv_files = list(self.data_dir.glob(filename))
if not csv_files:
logger.error(f"{self.data_dir} 中没有找到文件: {filename}")
return None
# 如果有多个文件,选择最新的
if len(csv_files) > 1:
csv_files.sort(key=os.path.getmtime, reverse=True)
logger.info(f"找到多个文件,选择最新的: {csv_files[0].name}")
return csv_files[0]
def validate_file(self, file_path: Path) -> bool:
"""验证CSV文件是否有效"""
try:
if not file_path.exists():
logger.error(f"CSV文件不存在: {file_path}")
return False
file_size = file_path.stat().st_size
if file_size == 0:
logger.error(f"CSV文件为空: {file_path}")
return False
return True
except Exception as e:
logger.error(f"文件验证失败: {e}")
return False
def detect_file_encoding(self, file_path: Path) -> str:
"""检测文件编码"""
try:
# 读取文件开头部分进行编码检测
with open(file_path, 'rb') as f:
raw_data = f.read(10000) # 读取前10KB
# 使用chardet检测编码
result = chardet.detect(raw_data)
encoding = result['encoding']
confidence = result['confidence']
# 常见编码替代
encoding_map = {
'GB2312': 'GBK',
'gb2312': 'GBK',
'ISO-8859-1': 'latin1',
'ascii': 'utf-8'
}
# 应用映射
encoding = encoding_map.get(encoding, encoding)
logger.info(f"检测到编码: {encoding} (置信度: {confidence:.2f})")
return encoding or 'utf-8'
except Exception as e:
logger.error(f"编码检测失败: {e}, 使用默认UTF-8")
return 'utf-8'
def detect_csv_delimiter(self, file_path: Path) -> str:
"""自动检测CSV分隔符"""
try:
# 使用检测到的编码打开文件
with open(file_path, 'r', encoding=self.encoding) as f:
# 读取前5行
lines = [f.readline() for _ in range(5) if f.readline()]
# 尝试常见分隔符
delimiters = [',', '\t', ';', '|']
delimiter_counts = {}
for delim in delimiters:
count = 0
for line in lines:
count += line.count(delim)
delimiter_counts[delim] = count
# 选择出现次数最多的分隔符
best_delim = max(delimiter_counts, key=delimiter_counts.get)
# 如果没有任何分隔符,则使用逗号
if delimiter_counts[best_delim] == 0:
logger.warning(f"无法检测到有效的分隔符,使用默认逗号分隔符")
return ','
logger.info(f"检测到分隔符: {repr(best_delim)}")
return best_delim
except Exception as e:
logger.error(f"检测分隔符失败: {e}, 使用默认逗号分隔符")
return ','
def read_csv_data(self, file_path: Path) -> bool:
"""从CSV文件读取数据"""
try:
# 1. 检测文件编码
self.encoding = self.detect_file_encoding(file_path)
# 2. 检测分隔符
self.delimiter = self.detect_csv_delimiter(file_path)
# 3. 读取CSV文件
logger.info(f"使用编码 '{self.encoding}' 和分隔符 '{self.delimiter}' 读取文件")
self.df = pd.read_csv(
file_path,
delimiter=self.delimiter,
dtype=str,
encoding=self.encoding,
on_bad_lines='warn',
quoting=csv.QUOTE_MINIMAL,
engine='python' # 更健壮的引擎
)
# 检查是否读取到数据
if self.df.empty:
logger.error("CSV文件没有包含有效数据")
return False
# 重命名列
self.df = self.df.rename(columns=self.column_mapping)
# 移除可能存在的空行
self.df = self.df.dropna(how='all')
logger.info(f"成功读取CSV数据{len(self.df)} 条记录")
return True
except UnicodeDecodeError:
# 尝试其他编码
encodings_to_try = ['GBK', 'latin1', 'ISO-8859-1', 'utf-16']
for enc in encodings_to_try:
try:
logger.warning(f"尝试使用 {enc} 编码读取文件")
self.df = pd.read_csv(
file_path,
delimiter=self.delimiter,
dtype=str,
encoding=enc
)
self.encoding = enc
logger.info(f"成功使用 {enc} 编码读取文件")
return True
except:
continue
logger.error("所有编码尝试均失败")
return False
except PermissionError:
logger.error(f"文件被占用,请关闭后重试: {file_path}")
return False
except Exception as e:
logger.error(f"读取CSV文件失败: {e}")
return False
def clean_stock_data(self) -> bool:
"""清洗股票数据(基本清洗,主要处理股票代码格式)"""
try:
# 验证股票代码格式如果存在stock_code列
if 'stock_code' in self.df.columns:
invalid_codes = self.df[~self.df['stock_code'].astype(str).str.match(r'^\d{6}$')]
if not invalid_codes.empty:
logger.warning(f"发现 {len(invalid_codes)} 条无效的股票代码")
logger.debug(f"无效代码示例: {invalid_codes['stock_code'].head().tolist()}")
logger.info("数据清洗完成")
return True
except Exception as e:
logger.error(f"数据清洗失败: {e}")
return False
def create_stocks_table(self, db: MySQLHelper) -> bool:
"""创建股票信息表(包含股票代码、中文名称、英文名称、进出标志和时间戳)"""
# 定义列类型映射
column_type_mapping = {
'stock_code': 'VARCHAR(6) PRIMARY KEY',
'stock_name_chn': 'VARCHAR(50) NULL',
'stock_name_en': 'VARCHAR(150)',
}
# 构建列定义SQL
column_definitions = []
for column_name in self.column_mapping.values():
if column_name in column_type_mapping:
column_definitions.append(f"{column_name} {column_type_mapping[column_name]}")
else:
# 对于未知列使用VARCHAR(255)
column_definitions.append(f"{column_name} VARCHAR(255)")
# 添加进出标志列
column_definitions.append("in_out TINYINT(1) DEFAULT 0 COMMENT '进出标志'")
# 添加时间戳列
column_definitions.append("created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间'")
column_definitions.append("updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间'")
# 构建完整的CREATE TABLE SQL
columns_sql = ",\n ".join(column_definitions)
create_table_sql = f"""
CREATE TABLE IF NOT EXISTS hk_stock_connect (
{columns_sql}
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='股票信息表';
"""
try:
db.execute_update(create_table_sql)
logger.info("股票信息表创建成功")
return True
except Exception as e:
logger.error(f"创建表失败: {e}")
return False
def insert_data_to_db(self, db: MySQLHelper) -> bool:
"""将数据插入数据库(处理映射的列和进出标志)"""
if self.df is None or self.df.empty:
logger.error("没有有效数据可插入")
return False
# 获取所有映射的列名
mapped_columns = list(self.column_mapping.values())
# 构建INSERT SQL语句包含in_out列设置为1
columns_sql = ", ".join(mapped_columns + ['in_out'])
placeholders = ", ".join(["%s"] * len(mapped_columns) + ["1"]) # in_out固定为1
# 构建ON DUPLICATE KEY UPDATE部分主键不更新但更新in_out字段
update_clauses = []
for column in mapped_columns:
if column != 'stock_code': # 主键不更新
update_clauses.append(f"{column} = VALUES({column})")
# 添加in_out字段更新确保在重复时也设置为1
update_clauses.append("in_out = VALUES(in_out)")
update_sql = ", ".join(update_clauses)
insert_sql = f"""
INSERT INTO hk_stock_connect (
{columns_sql}
) VALUES (
{placeholders}
)
ON DUPLICATE KEY UPDATE
{update_sql}
"""
# 准备参数列表只包含映射的列in_out由SQL固定为1
params_list = []
for _, row in self.df.iterrows():
params = []
for column in mapped_columns:
# 处理可能的NaN值
value = row[column] if column in row and pd.notna(row[column]) else None
params.append(value)
params_list.append(tuple(params))
# 批量执行插入
try:
total_rows = len(params_list)
if total_rows == 0:
logger.error("没有有效数据可插入")
return False
batch_size = 1000 # 每批插入1000条记录
logger.info(f"开始插入数据,共 {total_rows} 条记录")
# 分批插入,避免大事务问题
for i in range(0, total_rows, batch_size):
batch_params = params_list[i:i+batch_size]
affected_rows = db.execute_many(insert_sql, batch_params)
logger.info(f"已处理 {min(i+batch_size, total_rows)}/{total_rows} 条记录")
logger.info(f"成功插入/更新 {total_rows} 条记录")
return True
except Exception as e:
logger.error(f"插入数据失败: {e}")
# 记录前5个参数以帮助调试
if params_list:
logger.debug(f"前5个参数示例: {params_list[:5]}")
return False
def setInOut(self, db: MySQLHelper) -> bool:
"""设置进出标志为0且不更新updated_at字段"""
try:
# # 首先检查表是否存在
# table_check_sql = """
# SELECT COUNT(*) AS table_exists
# FROM information_schema.tables
# WHERE table_schema = DATABASE()
# AND table_name = 'hk_stock_connect'
# """
# result = db.execute_query(table_check_sql)
# if not result or result[0]['table_exists'] == 0:
# logger.warning("表 'hk_stock_connect' 不存在,跳过设置进出标志操作")
# return True
# 表存在,执行更新操作
update_sql = "UPDATE hk_stock_connect SET in_out = 1, updated_at = updated_at"
affected_rows = db.execute_update(update_sql)
logger.info(f"成功设置 {affected_rows} 条记录的进出标志为0")
return True
except Exception as e:
logger.error(f"设置进出标志失败: {e}")
return False
def verify_data_in_db(self, db: MySQLHelper, sample_size: int = 5) -> bool:
"""验证数据库中的数据"""
try:
# 检查记录总数
count_sql = "SELECT COUNT(*) AS total FROM hk_stock_connect"
result = db.execute_query(count_sql)
db_count = result[0]['total'] if result else 0
logger.info(f"数据库中共有 {db_count} 条记录")
# 获取映射的列名用于显示
mapped_columns = list(self.column_mapping.values())
# 构建查询列使用所有映射的列和in_out字段
select_columns = mapped_columns + ['in_out'] if mapped_columns else ["*"]
columns_sql = ", ".join(select_columns)
# 随机抽样检查
sample_sql = f"""
SELECT {columns_sql}
FROM hk_stock_connect
ORDER BY RAND()
LIMIT {sample_size}
"""
samples = db.execute_query(sample_sql)
logger.info("\n随机抽样记录:")
for idx, sample in enumerate(samples, 1):
# 动态构建日志消息显示所有映射的列和in_out字段
sample_info = []
for column in select_columns:
if column in sample and sample[column] is not None:
sample_info.append(f"{column}: {sample[column]}")
logger.info(f"{idx}. {' | '.join(sample_info) if sample_info else 'No data to display'}")
return True
except Exception as e:
logger.error(f"数据验证失败: {e}")
return False
def run_import(self) -> bool:
"""执行完整的导入流程"""
logger.info(f"开始导入股票数据,数据目录: {self.data_dir}")
start_time = datetime.now()
# 1. 查找CSV文件
csv_file = self.find_csv_file()
if not csv_file:
return False
# 2. 验证文件
if not self.validate_file(csv_file):
return False
# 3. 读取CSV数据
if not self.read_csv_data(csv_file):
return False
# 4. 清洗数据
if not self.clean_stock_data():
return False
# 显示前5条数据动态显示可用列
logger.info("\n前5条股票数据:")
for i, row in self.df.head().iterrows():
# 动态构建显示信息
display_info = []
if 'stock_code' in row:
display_info.append(f"代码: {row['stock_code']}")
if 'stock_name_chn' in row:
display_info.append(f"名称: {row['stock_name_chn']}")
if 'stock_name_en' in row:
display_info.append(f"英文: {row['stock_name_en']}")
logger.info(f"{i+1}. {' | '.join(display_info)}")
# 5. 连接数据库并导入
try:
with MySQLHelper(**self.db_config) as db:
# 5.1 创建表
if not self.create_stocks_table(db):
return False
# 更新检讨结果
self.setInOut(db)
# 5.2 插入数据
if not self.insert_data_to_db(db):
return False
# 5.3 验证数据
if not self.verify_data_in_db(db):
return False
except Exception as e:
logger.error(f"数据库操作异常: {e}")
return False
# 计算执行时间
duration = datetime.now() - start_time
logger.info(f"数据处理成功完成! 总耗时: {duration.total_seconds():.2f}")
return True
if __name__ == "__main__":
# 数据库配置
db_config = {
'host': 'localhost',
'user': 'root',
'password': 'bzskmysql',
'database': 'hk_kline_1d'
}
# 列映射配置
COLUMN_MAPPING = {
'证券代码': 'stock_code',
'中文简称': 'stock_name_chn',
'英文简称': 'stock_name_en',
}
# 获取当前脚本所在目录
current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd()
# 设置数据目录
DATA_DIR = current_dir / "data"
# 确保data目录存在
DATA_DIR.mkdir(exist_ok=True, parents=True)
# 安装依赖 (如果chardet未安装)
try:
import chardet
except ImportError:
logger.info("安装chardet库以支持编码检测...")
import subprocess
subprocess.check_call([sys.executable, "-m", "pip", "install", "chardet"])
import chardet
# 创建导入器并执行导入(使用新的参数顺序)
importer = StockDataImporter(db_config, COLUMN_MAPPING, DATA_DIR)
# 2.2w设置上传文件名(可以注释掉使用默认文件名)
importer.setUploadfile("港股通标的证券名单.csv")
if importer.run_import():
logger.info("股票数据导入成功!")
else:
logger.error("股票数据导入失败,请检查日志了解详情")