430 lines
13 KiB
Python
430 lines
13 KiB
Python
"""
|
||
使用 akshare 数据源获取/更新A股数据
|
||
"""
|
||
import logging
import os
import re
import time
from datetime import datetime, timedelta

import akshare as ak
import pandas as pd

from LogHelper import LogHelper
from MySQLHelper import MySQLHelper  # project-local database helper class
|
||
|
||
# Module-wide logger named 'AkShare'. LogHelper's basic setup is expected to
# create a date-stamped log file plus console output (that behaviour lives in
# LogHelper, not here).
_akshare_log_helper = LogHelper(logger_name='AkShare')
logger = _akshare_log_helper.setup()
|
||
|
||
# Shared MySQL connection settings for both databases used by this script.
# Each credential can be overridden through an environment variable so the
# password does not have to be kept in source control. The defaults reproduce
# the previously hard-coded values, so existing setups keep working.
_DB_BASE_CONFIG = {
    'host': os.environ.get('MYSQL_HOST', 'localhost'),
    'user': os.environ.get('MYSQL_USER', 'root'),
    'password': os.environ.get('MYSQL_PASSWORD', 'bzskmysql'),
    'port': int(os.environ.get('MYSQL_PORT', '3306')),
    'charset': 'utf8mb4',
}

# Stock-list database (holds the stocks_sh / stocks_sz code tables).
DB_CONFIG = {**_DB_BASE_CONFIG, 'database': 'fullmarketdata_a'}

# Daily K-line database (one table per stock, e.g. sh_600000).
DB_CONFIG_1D = {**_DB_BASE_CONFIG, 'database': 'klinedata_1d_ma'}
|
||
|
||
# 方法1:显式连接和关闭
|
||
def _read_codes_with_explicit_connection(query: str) -> list:
    """
    Run ``query`` on a freshly opened DB_CONFIG connection and collect the
    truthy ``a_stock_code`` values from the returned rows.

    The connection is opened and closed by hand (close runs even when connect
    fails). A failed connect or any exception is logged and yields [].
    """
    helper = MySQLHelper(**DB_CONFIG)
    try:
        if not helper.connect():
            logger.error("数据库连接失败")
            return []
        rows = helper.execute_query(query)
        return [row['a_stock_code'] for row in rows if row['a_stock_code']]
    except Exception as err:
        logger.error(f"获取股票代码时出错: {err}")
        return []
    finally:
        helper.close()


def get_SH_stock_codes() -> list:
    """
    Return every non-empty ``a_stock_code`` stored in the ``stocks_sh`` table.

    Uses an explicitly opened/closed connection; returns [] on failure.
    """
    return _read_codes_with_explicit_connection("SELECT a_stock_code FROM stocks_sh")


def get_SZ_stock_codes() -> list:
    """
    Return every non-empty ``a_stock_code`` stored in the ``stocks_sz`` table.

    Uses an explicitly opened/closed connection; returns [] on failure.
    """
    return _read_codes_with_explicit_connection("SELECT a_stock_code FROM stocks_sz")
|
||
|
||
# 方法2:使用上下文管理器(推荐)
|
||
def _read_codes_within_context(query: str) -> list:
    """
    Run ``query`` inside a ``with MySQLHelper(**DB_CONFIG)`` block and collect
    the truthy ``a_stock_code`` values from the returned rows.

    Connection setup/teardown is delegated to MySQLHelper's context-manager
    protocol. Only errors raised inside the block are caught (logged, then []
    is returned); an error raised while entering the context propagates.
    """
    with MySQLHelper(**DB_CONFIG) as conn:
        try:
            fetched = conn.execute_query(query)
            return [item['a_stock_code'] for item in fetched if item['a_stock_code']]
        except Exception as err:
            logger.error(f"获取股票代码时出错: {err}")
            return []


def get_SH_stock_codes_with_context() -> list:
    """Return all non-empty ``a_stock_code`` values from ``stocks_sh`` (context-managed connection)."""
    return _read_codes_within_context("SELECT a_stock_code FROM stocks_sh")


def get_SZ_stock_codes_with_context() -> list:
    """Return all non-empty ``a_stock_code`` values from ``stocks_sz`` (context-managed connection)."""
    return _read_codes_within_context("SELECT a_stock_code FROM stocks_sz")
|
||
|
||
def get_daily_k_data(stock_code: str, start_date: str, end_date: str) -> pd.DataFrame:
    """
    Fetch forward-adjusted (qfq) daily K-line data for a single stock via akshare.

    NOTE(review): get_daily_k_data is defined twice in this module; the later
    definition replaces this one at import time, so this body only runs if
    that later definition is removed. Both are kept behaviourally identical.

    Args:
        stock_code: Stock code, either bare (e.g. ``600000``) or exchange-prefixed
            (e.g. ``sh600000``). A prefix is stripped because
            ``ak.stock_zh_a_hist`` expects the bare 6-digit code.
        start_date: First date to fetch, ``YYYYMMDD``.
        end_date: Last date to fetch, ``YYYYMMDD``.

    Returns:
        DataFrame with English column names (date, code, open, close, high, low,
        volume, amount, amplitude, change_percent, change_amount, turnover).
        Any unrecognised columns are left unchanged. An empty DataFrame is
        returned when akshare yields no rows or any error occurs (the error
        is logged).
    """
    # Map akshare's Chinese headers by name instead of by position. The column
    # set differs between akshare releases (with or without 股票代码), so a
    # fixed-length positional rename breaks on one of them.
    column_map = {
        '日期': 'date',
        '股票代码': 'code',
        '开盘': 'open',
        '收盘': 'close',
        '最高': 'high',
        '最低': 'low',
        '成交量': 'volume',
        '成交额': 'amount',
        '振幅': 'amplitude',
        '涨跌幅': 'change_percent',
        '涨跌额': 'change_amount',
        '换手率': 'turnover',
    }
    try:
        # stock_zh_a_hist takes the bare code, so drop an exchange prefix if present.
        if stock_code[:2].lower() in ('sh', 'sz', 'bj'):
            symbol = stock_code[2:]
        else:
            symbol = stock_code

        df = ak.stock_zh_a_hist(
            symbol=symbol,
            period="daily",
            start_date=start_date,
            end_date=end_date,
            adjust="qfq",  # forward-adjusted prices
        )
        if df is None or df.empty:
            return pd.DataFrame()

        df = df.rename(columns=column_map)
        # Releases without a code column still need one for the database row.
        if 'code' not in df.columns:
            df.insert(1, 'code', symbol)
        return df

    except Exception as e:
        logger.error(f"获取 {stock_code} 日K数据时出错: {e}")
        return pd.DataFrame()
|
||
|
||
def format_stock_code(code: str) -> str:
    """
    Prefix a stock code with its exchange, in the form akshare expects.

    Rules (by leading digit):
        6      -> Shanghai Stock Exchange  (sh)
        0 or 3 -> Shenzhen Stock Exchange  (sz)
        4 or 8 -> Beijing Stock Exchange   (bj)

    Codes that already carry an sh/sz/bj prefix are returned unchanged.
    Unrecognised codes are logged and returned as-is so akshare can try them.
    """
    already_prefixed = code.startswith(('sh', 'sz', 'bj'))
    if already_prefixed:
        return code

    exchange_rules = (
        (('6',), 'sh'),
        (('0', '3'), 'sz'),
        (('4', '8'), 'bj'),
    )
    for leading_digits, exchange in exchange_rules:
        if code.startswith(leading_digits):
            return f"{exchange}{code}"

    logger.error(f"无法识别的股票代码格式: {code}")
    return code
|
||
|
||
def get_daily_k_data(stock_code: str, start_date: str, end_date: str) -> pd.DataFrame:
    """
    Fetch forward-adjusted (qfq) daily K-line data for a single stock via akshare.

    Args:
        stock_code: Stock code, either bare (e.g. ``600000``) or exchange-prefixed
            (e.g. ``sh600000``). A prefix is stripped because
            ``ak.stock_zh_a_hist`` expects the bare 6-digit code.
        start_date: First date to fetch, ``YYYYMMDD``.
        end_date: Last date to fetch, ``YYYYMMDD``.

    Returns:
        DataFrame with English column names (date, code, open, close, high, low,
        volume, amount, amplitude, change_percent, change_amount, turnover).
        Any unrecognised columns are left unchanged. An empty DataFrame is
        returned when akshare yields no rows or any error occurs (the error
        is logged).
    """
    # Map akshare's Chinese headers by name instead of by position. The column
    # set differs between akshare releases (with or without 股票代码), so a
    # fixed-length positional rename raises on one of them. The error was then
    # swallowed into an empty frame.
    column_map = {
        '日期': 'date',
        '股票代码': 'code',
        '开盘': 'open',
        '收盘': 'close',
        '最高': 'high',
        '最低': 'low',
        '成交量': 'volume',
        '成交额': 'amount',
        '振幅': 'amplitude',
        '涨跌幅': 'change_percent',
        '涨跌额': 'change_amount',
        '换手率': 'turnover',
    }
    try:
        # stock_zh_a_hist takes the bare code, so drop an exchange prefix if present.
        if stock_code[:2].lower() in ('sh', 'sz', 'bj'):
            symbol = stock_code[2:]
        else:
            symbol = stock_code

        df = ak.stock_zh_a_hist(
            symbol=symbol,
            period="daily",
            start_date=start_date,
            end_date=end_date,
            adjust="qfq",  # forward-adjusted prices
        )
        if df is None or df.empty:
            return pd.DataFrame()

        df = df.rename(columns=column_map)
        # Releases without a code column still need one for the database row.
        if 'code' not in df.columns:
            df.insert(1, 'code', symbol)
        return df

    except Exception as e:
        logger.error(f"获取 {stock_code} 日K数据时出错: {e}")
        return pd.DataFrame()
|
||
|
||
def create_stock_table(db: MySQLHelper, table_name: str) -> bool:
    """
    Create the per-stock daily K-line table if it does not already exist.

    Args:
        db: Connected MySQLHelper for the daily-K database.
        table_name: Table name in ``<exchange>_<6-digit code>`` form, e.g. ``sh_600000``.

    Returns:
        True if the DDL executed without raising. False if the name is invalid
        or the statement failed (the reason is logged).

    Schema notes:
        * ``code`` is VARCHAR. Stored as DECIMAL, a code such as ``000001``
          loses its leading zeros and a prefixed code such as ``sh600000``
          cannot be stored at all.
        * Percentage and change columns use DECIMAL(10, 2). DECIMAL(5, 2)
          tops out at 999.99, and an out-of-range value in MySQL strict mode
          fails the whole batch insert.
        * ``date`` is the PRIMARY KEY, which is already an index, so no
          separate ``idx_date`` index is created.
        * Because of ``IF NOT EXISTS``, tables that already exist keep the
          schema they were created with.
    """
    # Identifiers cannot be bound as query parameters, so the name is validated
    # strictly before interpolation. fullmatch (unlike match with ``$``) also
    # rejects a trailing newline.
    if not re.fullmatch(r"[a-z]{2}_[0-9]{6}", table_name):
        logger.error(f"表名 '{table_name}' 不符合命名规则")
        return False

    create_table_sql = f"""
    CREATE TABLE IF NOT EXISTS `{table_name}` (
        `date` DATE NOT NULL COMMENT '日期',
        `code` VARCHAR(10) NOT NULL COMMENT '代码',
        `open` DECIMAL(10, 2) NOT NULL COMMENT '开盘价',
        `close` DECIMAL(10, 2) NOT NULL COMMENT '收盘价',
        `high` DECIMAL(10, 2) NOT NULL COMMENT '最高价',
        `low` DECIMAL(10, 2) NOT NULL COMMENT '最低价',
        `volume` BIGINT NOT NULL COMMENT '成交量(手)',
        `amount` DECIMAL(20, 2) NOT NULL COMMENT '成交额(元)',
        `amplitude` DECIMAL(10, 2) NOT NULL COMMENT '振幅(%)',
        `change_percent` DECIMAL(10, 2) NOT NULL COMMENT '涨跌幅(%)',
        `change_amount` DECIMAL(10, 2) NOT NULL COMMENT '涨跌额(元)',
        `turnover` DECIMAL(10, 2) NOT NULL COMMENT '换手率(%)',
        PRIMARY KEY (`date`)
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='股票日K数据表';
    """

    try:
        db.execute_update(create_table_sql)
        return True
    except Exception as e:
        logger.error(f"创建表 {table_name} 失败: {e}")
        return False
|
||
|
||
def save_stock_data_to_db(db: MySQLHelper, df: pd.DataFrame, table_name: str) -> int:
    """
    Upsert daily K-line rows from ``df`` into ``table_name``.

    Args:
        db: Connected MySQLHelper for the daily-K database.
        df: DataFrame containing at least the columns date, code, open, close,
            high, low, volume, amount, amplitude, change_percent,
            change_amount and turnover. Extra columns are ignored.
        table_name: Table name in ``<exchange>_<6-digit code>`` form, e.g. ``sh_600000``.

    Returns:
        On success, the value reported by ``db.execute_many``. This is
        presumably the driver's affected-row count; with ON DUPLICATE KEY
        UPDATE, MySQL counts an updated row as 2. Returns 0 when nothing is
        written: empty frame, invalid table name, missing columns, no complete
        rows, or a database error (the reason is logged).
    """
    if df.empty:
        return 0

    # Identifiers cannot be bound as parameters, so the name is validated
    # strictly before interpolation. fullmatch also rejects a trailing newline.
    if not re.fullmatch(r"[a-z]{2}_[0-9]{6}", table_name):
        logger.error(f"表名 '{table_name}' 不符合命名规则")
        return 0

    columns = [
        'date', 'code', 'open', 'close', 'high', 'low',
        'volume', 'amount', 'amplitude', 'change_percent',
        'change_amount', 'turnover',
    ]
    missing = [col for col in columns if col not in df.columns]
    if missing:
        logger.error(f"表 {table_name}: 数据缺少列 {missing}")
        return 0

    # Build the statement from one column list so the column order, the
    # placeholder count and the update clause cannot drift apart.
    column_sql = ", ".join(f"`{col}`" for col in columns)
    placeholder_sql = ", ".join(["%s"] * len(columns))
    update_sql = ", ".join(f"`{col}` = VALUES(`{col}`)" for col in columns[1:])
    insert_sql = (
        f"INSERT INTO `{table_name}` ({column_sql}) VALUES ({placeholder_sql}) "
        f"ON DUPLICATE KEY UPDATE {update_sql}"
    )

    # Every column is NOT NULL, and NaN is typically rejected by MySQL drivers.
    # A single incomplete row would therefore fail the whole batch, so skip
    # such rows and report how many were dropped.
    frame = df[columns].dropna()
    dropped = len(df) - len(frame)
    if dropped:
        logger.error(f"表 {table_name}: 跳过 {dropped} 条含空值的记录")
    if frame.empty:
        return 0

    # itertuples(name=None) yields plain tuples in column order, with Python
    # scalars for numeric columns. It is considerably cheaper than iterrows().
    rows = list(frame.itertuples(index=False, name=None))

    try:
        affected_rows = db.execute_many(insert_sql, rows)
    except Exception as e:
        logger.error(f"保存数据到表 {table_name} 失败: {e}")
        return 0

    logger.info(f"表 {table_name}: 成功插入/更新 {affected_rows} 条记录")
    return affected_rows
|
||
|
||
def generate_table_name(stock_code: str) -> str:
    """
    Build the per-stock table name ``<exchange>_<code>`` from a stock code.

    Args:
        stock_code: Stock code with or without an exchange prefix, e.g.
            ``600000`` or ``sh600000``. The prefix is matched case-insensitively.

    Returns:
        str: Table name such as ``sh_600000``.
        - An sh/sz/bj prefix, if present, is used as the exchange.
        - Otherwise the exchange is chosen by leading digit: 6 -> sh,
          0/3 -> sz, 4/8 -> bj.
        - Any other code yields ``unknown_<code>``.
    """
    # Prefixed input (e.g. "sh600000") previously fell through to "unknown_...",
    # contradicting the documented contract.
    exchange = stock_code[:2].lower()
    if exchange in ('sh', 'sz', 'bj'):
        return f"{exchange}_{stock_code[2:]}"

    if stock_code.startswith('6'):
        return f"sh_{stock_code}"
    if stock_code.startswith(('0', '3')):
        return f"sz_{stock_code}"
    if stock_code.startswith(('4', '8')):
        return f"bj_{stock_code}"

    # Unrecognised code: the caller's table-name validation will reject this name.
    return f"unknown_{stock_code}"
|
||
|
||
if __name__ == "__main__":
    # 1-based position in all_stock_codes to resume from. Earlier codes are
    # skipped. Defaults to the previously hard-coded 1584; set
    # AKSHARE_RESUME_FROM=1 to process every code.
    resume_from = int(os.environ.get('AKSHARE_RESUME_FROM', '1584'))
    # Pause after each stock so akshare requests are not sent too quickly.
    request_interval_seconds = 5

    # Read the stock codes (Shanghai first, then Shenzhen).
    logger.info("读取股票代码")
    sh_stock_codes_context = get_SH_stock_codes_with_context()
    sz_stock_codes_context = get_SZ_stock_codes_with_context()
    all_stock_codes = sh_stock_codes_context + sz_stock_codes_context

    if all_stock_codes:
        logger.info(f"前五个代码:{all_stock_codes[:5]}")
        # [-5:] is the last five codes; [-6:-1] silently dropped the final one.
        logger.info(f"后五个代码:{all_stock_codes[-5:]}")

    # Store daily K-line data. The connection is always closed, and the script
    # stops instead of attempting every write on a dead connection.
    db_1d = MySQLHelper(**DB_CONFIG_1D)
    try:
        if not db_1d.connect():
            logger.error("数据库连接失败")
            raise SystemExit(1)

        # Roughly the last three years, up to and including today.
        start_date = (datetime.now() - timedelta(days=3 * 365)).strftime("%Y%m%d")
        end_date = (datetime.now() + timedelta(days=1)).strftime("%Y%m%d")
        logger.info(f"获取数据时间范围: {start_date} 至 {end_date}")

        for position, code in enumerate(all_stock_codes, start=1):
            if position < resume_from:
                continue

            df = get_daily_k_data(code, start_date, end_date)

            # Table name is <exchange>_<code>.
            table_name = generate_table_name(code)

            # Create the table if needed; without a table, saving cannot succeed.
            if not create_stock_table(db_1d, table_name):
                logger.error(f"无法为股票 {code} 创建表 {table_name}")
            else:
                save_stock_data_to_db(db_1d, df, table_name)

            time.sleep(request_interval_seconds)
    finally:
        db_1d.close()