2025-08-15 23:52:12 +08:00
|
|
|
|
"""
|
2025-08-18 14:05:59 +08:00
|
|
|
|
存储深圳交易所股票列表数据
|
2025-08-15 23:52:12 +08:00
|
|
|
|
|
2025-08-18 14:05:59 +08:00
|
|
|
|
由于不确定深交所网站的数据爬取规则,为防止 IP 被封,
|
|
|
|
|
|
暂时采用以下方案获取股票列表数据:
|
|
|
|
|
|
—— 手动下载 Excel 文件(转存为 CSV)后导入数据库
|
2025-08-15 23:52:12 +08:00
|
|
|
|
"""
|
|
|
|
|
|
from pathlib import Path
|
2025-08-18 14:05:59 +08:00
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from MySQLHelper import MySQLHelper
|
|
|
|
|
|
from LogHelper import LogHelper
|
|
|
|
|
|
import pandas as pd
|
2025-08-15 23:52:12 +08:00
|
|
|
|
import os
|
|
|
|
|
|
import sys
|
|
|
|
|
|
import csv
|
|
|
|
|
|
import chardet
|
|
|
|
|
|
|
2025-08-20 17:30:14 +08:00
|
|
|
|
# Module-level logger used by every StockDataImporter method and the script entry point.
logger = LogHelper(
    logger_name='SZ_Import',
).setup()
|
2025-08-15 23:52:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StockDataImporter:
    """Import the SZSE A-share list CSV ("A股列表.csv") into MySQL table ``stocks_sz``.

    :meth:`run_import` drives the whole pipeline: locate and validate the CSV
    in ``data_dir``, detect its encoding and delimiter, read it with pandas,
    clean the columns, then create the table if needed and upsert every row
    keyed on ``a_stock_code``.
    """

    # Chinese CSV header -> database column name. Header matching in
    # read_csv_data ignores all whitespace (and a stray UTF-8 BOM), so e.g.
    # both '地区' and '地 区' map to 'region'.
    COLUMN_MAPPING = {
        '板块': 'market_type',
        '公司全称': 'company_full_name',
        '英文名称': 'eng_name',
        '注册地址': 'registered_address',
        'A股代码': 'a_stock_code',
        'A股简称': 'a_stock_short_name',
        'A股上市日期': 'a_listing_date',
        'A股总股本': 'a_total_shares',
        'A股流通股本': 'a_circulating_shares',
        'B股代码': 'b_stock_code',
        'B股 简 称': 'b_stock_short_name',
        'B股上市日期': 'b_listing_date',
        'B股总股本': 'b_total_shares',
        'B股流通股本': 'b_circulating_shares',
        '地 区': 'region',
        '省 份': 'province',
        '城 市': 'city',
        '所属行业': 'industry',
        '公司网址': 'website',
        '未盈利': 'unprofitable',
        '具有表决权差异安排': 'voting_rights_difference',
        '具有协议控制架构': 'agreement_control_structure'
    }

    # Encodings tried, in order, after the detected one fails. 'latin1' can
    # decode any byte sequence, so it must stay last.
    FALLBACK_ENCODINGS = ('GBK', 'GB18030', 'utf-8-sig', 'utf-16', 'latin1')

    # Lower-cased cell texts that count as "yes" in the yes/no flag columns.
    TRUE_VALUES = frozenset({'1', 'y', 'yes', 'true', '是'})

    # Lower-cased cell texts that count as "no" in the yes/no flag columns.
    FALSE_VALUES = frozenset({'0', 'n', 'no', 'false', '否', '-', ''})

    def __init__(self, data_dir: Path, db_config: dict):
        """Create an importer.

        Args:
            data_dir: Directory that contains the CSV file (str is accepted too).
            db_config: Keyword arguments passed to ``MySQLHelper``.
        """
        self.data_dir = Path(data_dir)
        self.db_config = db_config
        self.df = None
        self.csv_file = None
        self.encoding = 'utf-8'  # default encoding
        self.delimiter = ','  # default delimiter

    def find_csv_file(self, pattern: str = "A股列表.csv") -> "Path | None":
        """Return the newest file in ``data_dir`` matching ``pattern``.

        Args:
            pattern: Glob pattern relative to ``data_dir``; the default is the
                fixed export file name.

        Returns:
            The matching file with the latest modification time, or None (after
            logging an error) when nothing matches.
        """
        csv_files = [p for p in self.data_dir.glob(pattern) if p.is_file()]
        if not csv_files:
            logger.error(f"在 {self.data_dir} 中没有找到CSV文件")
            return None

        # Several matches are only possible with a wildcard pattern: keep the newest.
        if len(csv_files) > 1:
            csv_files.sort(key=lambda p: p.stat().st_mtime, reverse=True)
            logger.info(f"找到多个CSV文件,选择最新的: {csv_files[0].name}")

        return csv_files[0]

    def validate_file(self, file_path: Path) -> bool:
        """Return True if ``file_path`` is an existing, non-empty regular file."""
        try:
            if not file_path.exists():
                logger.error(f"CSV文件不存在: {file_path}")
                return False
            if not file_path.is_file():
                logger.error(f"CSV路径不是文件: {file_path}")
                return False
            if file_path.stat().st_size == 0:
                logger.error(f"CSV文件为空: {file_path}")
                return False
            return True
        except Exception as e:
            logger.error(f"文件验证失败: {e}")
            return False

    def detect_file_encoding(self, file_path: Path, sample_size: int = 10000) -> str:
        """Guess the text encoding of ``file_path`` with chardet.

        Only the first ``sample_size`` bytes are inspected. GB2312 is widened to
        GBK, ISO-8859-1 is reported as 'latin1' and ASCII as 'utf-8'.

        Returns:
            The encoding name; 'utf-8' when detection fails or finds nothing.
        """
        encoding_map = {
            'gb2312': 'GBK',
            'iso-8859-1': 'latin1',
            'ascii': 'utf-8',
        }
        try:
            with open(file_path, 'rb') as f:
                raw_data = f.read(sample_size)

            result = chardet.detect(raw_data)
            detected = result.get('encoding')
            confidence = result.get('confidence') or 0.0
            encoding = encoding_map.get(detected.lower(), detected) if detected else None

            logger.info(f"检测到编码: {encoding} (置信度: {confidence:.2f})")
            return encoding or 'utf-8'
        except Exception as e:
            logger.error(f"编码检测失败: {e}, 使用默认UTF-8")
            return 'utf-8'

    def detect_csv_delimiter(self, file_path: Path, sample_lines: int = 5) -> str:
        """Guess the field delimiter from the first ``sample_lines`` lines.

        The file is decoded with ``self.encoding``. Among ',', tab, ';' and '|'
        the most frequent character wins; ',' is returned when none occurs or
        reading fails.
        """
        candidates = [',', '\t', ';', '|']
        try:
            with open(file_path, 'r', encoding=self.encoding) as f:
                # zip() stops after sample_lines without consuming extra lines.
                lines = [line for _, line in zip(range(sample_lines), f)]

            counts = {d: sum(line.count(d) for line in lines) for d in candidates}
            best_delim = max(counts, key=counts.get)

            if counts[best_delim] == 0:
                logger.warning("无法检测到有效的分隔符,使用默认逗号分隔符")
                return ','

            logger.info(f"检测到分隔符: {repr(best_delim)}")
            return best_delim
        except Exception as e:
            logger.error(f"检测分隔符失败: {e}, 使用默认逗号分隔符")
            return ','

    @classmethod
    def _rename_columns(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Rename CSV headers to DB column names, ignoring whitespace and a BOM."""
        lookup = {''.join(key.split()): value for key, value in cls.COLUMN_MAPPING.items()}
        renames = {}
        for column in df.columns:
            key = ''.join(str(column).replace('\ufeff', '').split())
            if key in lookup:
                renames[column] = lookup[key]
        return df.rename(columns=renames)

    def read_csv_data(self, file_path: Path) -> bool:
        """Read ``file_path`` into ``self.df`` using DB column names.

        The detected encoding is tried first, then each of
        ``FALLBACK_ENCODINGS``; the delimiter is re-detected for every
        candidate. An attempt is rejected when decoding fails, or when no
        'a_stock_code' column exists after header normalization (the usual
        symptom of a wrong but decodable encoding such as latin1).

        On success ``self.df``, ``self.encoding`` and ``self.delimiter``
        describe the data that was read; fully empty rows are dropped.

        Returns:
            True on success, False (after logging) otherwise.
        """
        detected = self.detect_file_encoding(file_path)
        candidates = [detected] + [
            enc for enc in self.FALLBACK_ENCODINGS if enc.lower() != detected.lower()
        ]
        decoded_once = False

        for attempt, enc in enumerate(candidates):
            if attempt:
                logger.warning(f"尝试使用 {enc} 编码读取文件")
            self.encoding = enc
            self.delimiter = self.detect_csv_delimiter(file_path)
            logger.info(f"使用编码 '{self.encoding}' 和分隔符 '{self.delimiter}' 读取文件")

            try:
                df = pd.read_csv(
                    file_path,
                    delimiter=self.delimiter,
                    dtype=str,  # keep codes such as '000001' as text
                    encoding=self.encoding,
                    on_bad_lines='warn',
                    quoting=csv.QUOTE_MINIMAL,
                    engine='python'  # more tolerant of irregular rows
                )
            except (UnicodeError, pd.errors.ParserError) as e:
                # Most likely the wrong encoding: try the next candidate.
                logger.debug(f"使用 {enc} 编码读取失败: {e}")
                continue
            except PermissionError:
                logger.error(f"文件被占用,请关闭后重试: {file_path}")
                return False
            except Exception as e:
                logger.error(f"读取CSV文件失败: {e}")
                return False

            decoded_once = True
            if df.empty:
                logger.error("CSV文件没有包含有效数据")
                return False

            df = self._rename_columns(df)
            if 'a_stock_code' not in df.columns:
                logger.warning(f"使用 {enc} 编码读取后未找到A股代码列")
                continue

            self.df = df.dropna(how='all')
            if attempt:
                logger.info(f"成功使用 {enc} 编码读取文件")
            logger.info(f"成功读取CSV数据,共 {len(self.df)} 条记录")
            return True

        if decoded_once:
            logger.error("CSV文件中未找到A股代码列,请检查表头")
        else:
            logger.error("所有编码尝试均失败")
        return False

    @staticmethod
    def _exchange_for_code(code: str):
        """Map a 6-digit A-share code to 'SH' or 'SZ' by prefix; None if unknown.

        None (stored as NULL) is used instead of a longer marker because the
        ``exchange`` column is VARCHAR(2).
        """
        if code.startswith(('60', '68')):
            return 'SH'
        if code.startswith(('00', '30')):
            return 'SZ'
        return None

    def clean_stock_data(self) -> bool:
        """Normalize ``self.df`` for database insertion.

        - Share counts: remove thousands separators and whitespace.
        - Listing dates: reformat as 'YYYY-MM-DD'; missing or unparseable -> ''.
        - Yes/no flags: bool ('是', 'Y', '1' or any non-zero number -> True).
        - A-share code: strip, restore leading zeros lost by Excel, and drop
          rows whose code is still not 6 digits (it is the table's unique key).
        - exchange: derived from the code prefix.
        - website: remove spaces, lower-case, add 'http://' when missing.

        Returns:
            True on success, False (after logging) on failure.
        """
        if self.df is None:
            logger.error("没有可清洗的数据")
            return False

        try:
            df = self.df.copy()

            for col in ('a_total_shares', 'a_circulating_shares',
                        'b_total_shares', 'b_circulating_shares'):
                if col in df.columns:
                    df[col] = (
                        df[col].fillna('').astype(str)
                        .str.replace(',', '', regex=False)
                        .str.replace(r'\s+', '', regex=True)
                    )

            for col in ('a_listing_date', 'b_listing_date'):
                if col in df.columns:
                    parsed = pd.to_datetime(
                        df[col].fillna('').astype(str).str.strip(),
                        errors='coerce'
                    )
                    # strftime yields NaN (not the string 'NaT') for missing dates.
                    df[col] = parsed.dt.strftime('%Y-%m-%d').fillna('')

            for col in ('unprofitable', 'voting_rights_difference', 'agreement_control_structure'):
                if col in df.columns:
                    text = df[col].fillna('').astype(str).str.strip().str.lower()
                    numeric = pd.to_numeric(text, errors='coerce').fillna(0)
                    # Text markers such as '是' must be checked before numeric
                    # coercion, which would turn them into 0/False.
                    df[col] = text.isin(self.TRUE_VALUES) | (numeric != 0)

            if 'a_stock_code' not in df.columns:
                logger.error("数据中缺少A股代码列")
                return False

            codes = df['a_stock_code'].fillna('').astype(str).str.strip()
            # Re-saving the export in Excel often turns '000001' into '1'.
            lost_zeros = codes.str.fullmatch(r'\d{1,5}')
            codes = codes.where(~lost_zeros, codes.str.zfill(6))
            df['a_stock_code'] = codes

            valid = codes.str.fullmatch(r'\d{6}')
            if not valid.all():
                invalid_codes = df[~valid]
                logger.warning(f"发现 {len(invalid_codes)} 条无效的A股代码")
                logger.debug(f"无效代码示例: {invalid_codes['a_stock_code'].head().tolist()}")
                logger.warning(f"已忽略 {len(invalid_codes)} 条无效代码的记录")
                df = df[valid].copy()

            df['exchange'] = df['a_stock_code'].map(self._exchange_for_code)

            if 'website' in df.columns:
                sites = (
                    df['website'].fillna('').astype(str)
                    .str.replace(' ', '', regex=False)
                    .str.lower()
                )
                df['website'] = sites.map(
                    lambda url: url if not url or url.startswith('http') else f'http://{url}'
                )

            self.df = df
            logger.info("数据清洗完成")
            return True
        except Exception as e:
            logger.error(f"数据清洗失败: {e}")
            return False

    def create_stocks_table(self, db: "MySQLHelper") -> bool:
        """Create table ``stocks_sz`` if it does not exist yet.

        Returns:
            True on success, False (after logging) on failure.
        """
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS stocks_sz (
            id INT AUTO_INCREMENT PRIMARY KEY,
            market_type VARCHAR(10) COMMENT '板块类型',
            company_full_name VARCHAR(100) NOT NULL COMMENT '公司全称',
            eng_name VARCHAR(150) COMMENT '英文名称',
            registered_address VARCHAR(200) COMMENT '注册地址',
            a_stock_code VARCHAR(6) NOT NULL COMMENT 'A股代码',
            a_stock_short_name VARCHAR(20) NOT NULL COMMENT 'A股简称',
            a_listing_date DATE COMMENT 'A股上市日期',
            a_total_shares BIGINT COMMENT 'A股总股本',
            a_circulating_shares BIGINT COMMENT 'A股流通股本',
            b_stock_code VARCHAR(6) COMMENT 'B股代码',
            b_stock_short_name VARCHAR(20) COMMENT 'B股简称',
            b_listing_date DATE COMMENT 'B股上市日期',
            b_total_shares BIGINT COMMENT 'B股总股本',
            b_circulating_shares BIGINT COMMENT 'B股流通股本',
            region VARCHAR(20) COMMENT '地区',
            province VARCHAR(20) COMMENT '省份',
            city VARCHAR(20) COMMENT '城市',
            industry VARCHAR(50) COMMENT '所属行业',
            website VARCHAR(100) COMMENT '公司网址',
            unprofitable BOOLEAN DEFAULT 0 COMMENT '未盈利',
            voting_rights_difference BOOLEAN DEFAULT 0 COMMENT '具有表决权差异安排',
            agreement_control_structure BOOLEAN DEFAULT 0 COMMENT '具有协议控制架构',
            exchange VARCHAR(2) COMMENT '交易所',
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
            UNIQUE KEY (a_stock_code)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='沪深股票详细信息表';
        """
        try:
            db.execute_update(create_table_sql)
            logger.info("股票信息表打开成功")
            return True
        except Exception as e:
            logger.error(f"创建表失败: {e}")
            return False

    def _row_to_params(self, row) -> tuple:
        """Convert one cleaned DataFrame row into the 23 INSERT parameters.

        Missing or NaN cells become None (NULL), except the columns with
        explicit defaults below. Share counts default to 0. Missing dates
        become None rather than a fake epoch date.
        """
        def get_value(col, default=None):
            value = row.get(col)
            if value is None or pd.isna(value):
                return default
            return value

        def get_numeric(col, default=0):
            value = get_value(col)
            if value is None or value == '':
                return default
            try:
                return int(value)
            except (TypeError, ValueError):
                try:
                    return int(float(value))  # e.g. '12345.0'
                except (TypeError, ValueError, OverflowError):
                    return default

        def get_date(col):
            value = get_value(col)
            return None if value in ('', 'NaT') else value

        def get_bool(col, default=False):
            value = get_value(col)
            if value is None:
                return default
            text = str(value).strip().lower()
            if text in self.TRUE_VALUES:
                return True
            if text in self.FALSE_VALUES:
                return False
            return default

        return (
            get_value('market_type'),
            get_value('company_full_name', ''),
            get_value('eng_name'),
            get_value('registered_address'),
            get_value('a_stock_code', ''),
            get_value('a_stock_short_name', ''),
            get_date('a_listing_date'),
            get_numeric('a_total_shares', 0),
            get_numeric('a_circulating_shares', 0),
            get_value('b_stock_code'),
            get_value('b_stock_short_name'),
            get_date('b_listing_date'),
            get_numeric('b_total_shares', 0),
            get_numeric('b_circulating_shares', 0),
            get_value('region'),
            get_value('province'),
            get_value('city'),
            get_value('industry'),
            get_value('website'),
            get_bool('unprofitable'),
            get_bool('voting_rights_difference'),
            get_bool('agreement_control_structure'),
            get_value('exchange'),
        )

    def insert_data_to_db(self, db: "MySQLHelper", batch_size: int = 500) -> bool:
        """Upsert ``self.df`` into ``stocks_sz`` in batches.

        Existing rows (same ``a_stock_code``) are updated via
        ON DUPLICATE KEY UPDATE.

        Args:
            db: Open database helper.
            batch_size: Rows per ``execute_many`` call.

        Returns:
            True on success, False (after logging) on failure.
        """
        if self.df is None or self.df.empty:
            logger.error("没有有效数据可插入")
            return False

        # Upsert statement; the 23 placeholders follow _row_to_params order.
        insert_sql = """
        INSERT INTO stocks_sz (
            market_type, company_full_name, eng_name, registered_address,
            a_stock_code, a_stock_short_name, a_listing_date, a_total_shares, a_circulating_shares,
            b_stock_code, b_stock_short_name, b_listing_date, b_total_shares, b_circulating_shares,
            region, province, city, industry, website,
            unprofitable, voting_rights_difference, agreement_control_structure,
            exchange
        ) VALUES (
            %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
        )
        ON DUPLICATE KEY UPDATE
            market_type = VALUES(market_type),
            company_full_name = VALUES(company_full_name),
            eng_name = VALUES(eng_name),
            registered_address = VALUES(registered_address),
            a_stock_short_name = VALUES(a_stock_short_name),
            a_listing_date = VALUES(a_listing_date),
            a_total_shares = VALUES(a_total_shares),
            a_circulating_shares = VALUES(a_circulating_shares),
            b_stock_code = VALUES(b_stock_code),
            b_stock_short_name = VALUES(b_stock_short_name),
            b_listing_date = VALUES(b_listing_date),
            b_total_shares = VALUES(b_total_shares),
            b_circulating_shares = VALUES(b_circulating_shares),
            region = VALUES(region),
            province = VALUES(province),
            city = VALUES(city),
            industry = VALUES(industry),
            website = VALUES(website),
            unprofitable = VALUES(unprofitable),
            voting_rights_difference = VALUES(voting_rights_difference),
            agreement_control_structure = VALUES(agreement_control_structure),
            exchange = VALUES(exchange)
        """

        params_list = [self._row_to_params(row) for _, row in self.df.iterrows()]

        try:
            total_rows = len(params_list)
            if total_rows == 0:
                logger.error("没有有效数据可插入")
                return False

            logger.info(f"开始插入数据,共 {total_rows} 条记录")

            # Insert in batches to avoid one huge statement/transaction.
            for start in range(0, total_rows, batch_size):
                db.execute_many(insert_sql, params_list[start:start + batch_size])
                logger.info(f"已处理 {min(start + batch_size, total_rows)}/{total_rows} 条记录")

            logger.info(f"成功插入/更新 {total_rows} 条记录")
            return True
        except Exception as e:
            logger.error(f"插入数据失败: {e}")
            # Log a few parameter tuples to help debugging.
            if params_list:
                logger.debug(f"前5个参数示例: {params_list[:5]}")
            return False

    @staticmethod
    def _cell(row, col) -> str:
        """Return ``row[col]`` as display text; '' when absent, None or NaN."""
        value = row.get(col)
        return '' if value is None or pd.isna(value) else str(value)

    def verify_data_in_db(self, db: "MySQLHelper", sample_size: int = 5) -> bool:
        """Log the row count of ``stocks_sz`` and ``sample_size`` random rows.

        Returns:
            True when both queries succeed, False (after logging) otherwise.
        """
        try:
            result = db.execute_query("SELECT COUNT(*) AS total FROM stocks_sz")
            db_count = result[0]['total'] if result else 0
            logger.info(f"数据库中共有 {db_count} 条记录")

            # int() guarantees only a number is interpolated into the SQL text.
            limit = max(int(sample_size), 0)
            sample_sql = f"""
            SELECT a_stock_code, a_stock_short_name, a_listing_date, province, city
            FROM stocks_sz
            ORDER BY RAND()
            LIMIT {limit}
            """
            samples = db.execute_query(sample_sql)

            logger.info("\n随机抽样记录:")
            for idx, sample in enumerate(samples, 1):
                logger.info(
                    f"{idx}. {sample['a_stock_code']}: {sample['a_stock_short_name']} "
                    f"({sample['a_listing_date']}) - "
                    f"{self._cell(sample, 'province')}{self._cell(sample, 'city')}"
                )
            return True
        except Exception as e:
            logger.error(f"数据验证失败: {e}")
            return False

    def run_import(self) -> bool:
        """Run the whole pipeline: find, validate, read, clean, store, verify.

        Returns:
            True when every step succeeds, False as soon as one fails.
        """
        logger.info(f"开始导入股票数据,数据目录: {self.data_dir}")
        start_time = datetime.now()

        # 1. Locate the CSV file.
        csv_file = self.find_csv_file()
        if not csv_file:
            return False
        self.csv_file = csv_file

        # 2. Validate it.
        if not self.validate_file(csv_file):
            return False

        # 3. Read it.
        if not self.read_csv_data(csv_file):
            return False

        # 4. Clean the data.
        if not self.clean_stock_data():
            return False

        # Preview the first five rows (missing columns print as empty).
        logger.info("\n前5条股票数据:")
        for _, row in self.df.head().iterrows():
            logger.info(
                f"{self._cell(row, 'a_stock_code')}: {self._cell(row, 'a_stock_short_name')} "
                f"({self._cell(row, 'a_listing_date')}) - "
                f"{self._cell(row, 'province')}{self._cell(row, 'city')}"
            )

        # 5. Store in the database.
        try:
            with MySQLHelper(**self.db_config) as db:
                if not self.create_stocks_table(db):
                    return False
                if not self.insert_data_to_db(db):
                    return False
                if not self.verify_data_in_db(db):
                    return False
        except Exception as e:
            logger.error(f"数据库操作异常: {e}")
            return False

        duration = datetime.now() - start_time
        logger.info(f"数据处理成功完成! 总耗时: {duration.total_seconds():.2f}秒")
        return True
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Database connection settings. Each value can be overridden through an
    # environment variable so credentials need not live in source control;
    # the defaults reproduce the previous hard-coded configuration.
    # NOTE(review): consider removing the password default altogether.
    DB_CONFIG = {
        'host': os.environ.get('SZ_IMPORT_DB_HOST', 'localhost'),
        'user': os.environ.get('SZ_IMPORT_DB_USER', 'root'),
        'password': os.environ.get('SZ_IMPORT_DB_PASSWORD', 'bzskmysql'),
        'database': os.environ.get('SZ_IMPORT_DB_NAME', 'fullmarketdata_a'),
    }

    # Directory of this script; falls back to the working directory when
    # __file__ is not defined (e.g. an interactive session).
    current_dir = Path(__file__).parent if "__file__" in globals() else Path.cwd()

    # The CSV export is expected in <script dir>/data; create the folder so the
    # user knows where to put it.
    DATA_DIR = current_dir / "data"
    DATA_DIR.mkdir(exist_ok=True, parents=True)

    # No runtime `pip install` here: chardet is imported at module top, so a
    # missing package already fails at import time, before this block runs.

    importer = StockDataImporter(DATA_DIR, DB_CONFIG)

    if importer.run_import():
        logger.info("股票数据导入成功!")
    else:
        logger.error("股票数据导入失败,请检查日志了解详情")
        # Non-zero exit status lets schedulers and shell scripts detect failure.
        sys.exit(1)
|