update exportdata
This commit is contained in:
@@ -143,10 +143,54 @@ class DataExporter:
|
||||
self.logger.error("无法获取月度均价数据")
|
||||
return False
|
||||
|
||||
# 读取港股通标记
|
||||
hk_inout_data = self.get_hk_inout()
|
||||
|
||||
for item in monthly_data:
|
||||
stock_name = item.get('stock_code')[3:]
|
||||
if stock_name in hk_inout_data:
|
||||
item['in_out'] = 1
|
||||
else:
|
||||
item['in_out'] = 0
|
||||
|
||||
# 导出结果
|
||||
file_path = 'data/' + csv_file if csv_file else None
|
||||
csv_success = True
|
||||
if csv_file:
|
||||
csv_success = self.export_to_csv(monthly_data, file_path)
|
||||
|
||||
return csv_success
|
||||
return csv_success
|
||||
|
||||
def get_hk_inout(self) -> Optional[List[str]]:
    """
    Read the codes of the stocks currently flagged as in Stock Connect.

    Queries the ``hk_stock_connect`` table for rows whose ``in_out``
    column equals ``'1'``, ordered by ``stock_code``.

    Returns:
        Optional[List[str]]: the non-empty ``stock_code`` values of the
        matching rows. ``None`` is returned when the query yields no rows
        or when any exception is raised, so callers must check for
        ``None`` before running membership tests on the result.
    """
    try:
        with MySQLHelper(**self.db_config) as db:
            # Only rows marked in_out = '1' are selected.
            data = db.execute_query("""
                SELECT stock_code
                FROM hk_stock_connect
                WHERE in_out = '1'
                ORDER BY stock_code
            """)

            if not data:
                self.logger.error("获取数据失败")
                return None

            # Drop rows whose stock_code is empty / NULL.
            return [row['stock_code'] for row in data if row['stock_code']]

    except Exception as e:
        # The original message was copy-pasted from the float-share reader;
        # it now names the data this method actually reads.
        self.logger.error(f"从数据库读取港股通标记失败: {str(e)}")
        return None
|
||||
@@ -16,7 +16,8 @@ from typing import Optional, List, Dict, Union, Tuple
|
||||
import csv
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from base import LogHelper, MySQLHelper, Config
|
||||
from base import LogHelper, MySQLHelper, ConfigInfo
|
||||
|
||||
|
||||
class MarketDataCalculator:
|
||||
"""
|
||||
@@ -41,8 +42,8 @@ class MarketDataCalculator:
|
||||
"""
|
||||
self.db_config = db_config
|
||||
self.logger = LogHelper(logger_name=logger_name).setup()
|
||||
self.month_ranges = Config.ConfigInfo.MONTH_RANGES
|
||||
self.head_map = Config.ConfigInfo.HEADER_MAP
|
||||
self.month_ranges = ConfigInfo.MONTH_RANGES
|
||||
self.head_map = ConfigInfo.HEADER_MAP
|
||||
|
||||
def create_monthly_avg_table(self, target_table: str = "monthly_close_avg") -> bool:
|
||||
"""
|
||||
@@ -69,6 +70,7 @@ class MarketDataCalculator:
|
||||
ym_2506 DECIMAL(20, 5) COMMENT '2025年06月',
|
||||
ym_2507 DECIMAL(20, 5) COMMENT '2025年07月',
|
||||
ym_2508 DECIMAL(20, 5) COMMENT '2025年08月',
|
||||
ym_2509 DECIMAL(20, 5) COMMENT '2025年09月',
|
||||
avg_all DECIMAL(20, 5) COMMENT '月间均值',
|
||||
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
|
||||
UNIQUE KEY uk_stock_code (stock_code)
|
||||
@@ -82,7 +84,7 @@ class MarketDataCalculator:
|
||||
return False
|
||||
|
||||
def calculate_and_save_monthly_avg(self,
|
||||
source_table: str = "stock_quotes",
|
||||
stock_code: str = "code",
|
||||
target_table: str = "monthly_close_avg") -> bool:
|
||||
"""
|
||||
计算并保存2024年10月至2025年8月的月均流通市值
|
||||
@@ -100,79 +102,85 @@ class MarketDataCalculator:
|
||||
return False
|
||||
|
||||
with MySQLHelper(**self.db_config) as db:
|
||||
|
||||
# 获取所有股票代码和名称
|
||||
stock_info = db.execute_query(
|
||||
f"SELECT DISTINCT stock_code, stock_name FROM {source_table}"
|
||||
)
|
||||
|
||||
if not stock_info:
|
||||
self.logger.error("没有获取到股票基本信息")
|
||||
return False
|
||||
|
||||
# 为每只股票计算各月均值
|
||||
for stock in stock_info:
|
||||
stock_code = stock['stock_code']
|
||||
stock_name = stock['stock_name']
|
||||
|
||||
monthly_data = {'stock_code': stock_code, 'stock_name': stock_name}
|
||||
|
||||
# 计算每个月的均值
|
||||
for month_col, (start_date, end_date) in self.month_ranges.items():
|
||||
sql = """
|
||||
SELECT AVG(close_price * float_share) as avg_close
|
||||
FROM {}
|
||||
WHERE stock_code = %s
|
||||
AND trade_date BETWEEN %s AND %s
|
||||
AND close_price IS NOT NULL
|
||||
AND float_share IS NOT NULL
|
||||
""".format(source_table)
|
||||
|
||||
result = db.execute_query(sql, (stock_code, start_date, end_date))
|
||||
# 保存小数点后两位,以亿为单位
|
||||
# monthly_data[month_col] = float(result[0]['avg_close']) * 1000 if result and result[0]['avg_close'] else None
|
||||
monthly_data[month_col] = round(float(result[0]['avg_close']) * 1000 / 100000000, 3) if result and result[0]['avg_close'] else None
|
||||
|
||||
# 提取所有以 'ym_' 开头的键的值
|
||||
ym_values = [value for key, value in monthly_data.items() if key.startswith('ym_')]
|
||||
valid_ym_values = [value for value in ym_values if value is not None]
|
||||
|
||||
# 计算全部月的均值
|
||||
if valid_ym_values:
|
||||
average = sum(valid_ym_values) / len(valid_ym_values)
|
||||
monthly_data['avg_all'] = average
|
||||
self.logger.debug(f"股票 {stock_code} 月间流通市值平均值为: {average}")
|
||||
else:
|
||||
monthly_data['avg_all'] = 0 # 给一个空值,保证数据库不报错
|
||||
self.logger.warning(f"股票 {stock_code} 没有有效的月度数据")
|
||||
|
||||
# 插入或更新数据
|
||||
upsert_sql = f"""
|
||||
INSERT INTO {target_table} (
|
||||
stock_code, stock_name,
|
||||
ym_2501, ym_2502, ym_2503, ym_2504,
|
||||
ym_2505, ym_2506,ym_2507, ym_2508,
|
||||
avg_all
|
||||
) VALUES (
|
||||
%(stock_code)s, %(stock_name)s,
|
||||
%(ym_2501)s, %(ym_2502)s, %(ym_2503)s, %(ym_2504)s,
|
||||
%(ym_2505)s, %(ym_2506)s, %(ym_2507)s, %(ym_2508)s,
|
||||
%(avg_all)s
|
||||
)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
stock_name = VALUES(stock_name),
|
||||
ym_2501 = VALUES(ym_2501),
|
||||
ym_2502 = VALUES(ym_2502),
|
||||
ym_2503 = VALUES(ym_2503),
|
||||
ym_2504 = VALUES(ym_2504),
|
||||
ym_2505 = VALUES(ym_2505),
|
||||
ym_2506 = VALUES(ym_2506),
|
||||
ym_2507 = VALUES(ym_2507),
|
||||
ym_2508 = VALUES(ym_2508),
|
||||
avg_all = VALUES(avg_all),
|
||||
update_time = CURRENT_TIMESTAMP
|
||||
sql = """
|
||||
SELECT stock_code, stock_name
|
||||
FROM stock_filter
|
||||
WHERE stock_code = %s
|
||||
"""
|
||||
db.execute_update(upsert_sql, monthly_data)
|
||||
|
||||
# stock_filter表格可以当成标准表
|
||||
stock_info = db.execute_query(sql,(stock_code))
|
||||
if len(stock_info) == 0:
|
||||
return
|
||||
|
||||
|
||||
stock_code = stock_info[0]['stock_code']
|
||||
stock_name = stock_info[0]['stock_name']
|
||||
|
||||
monthly_data = {'stock_code': stock_code, 'stock_name': stock_name}
|
||||
|
||||
# 计算每个月的均值
|
||||
source_table = 'hk_' + stock_code[3:]
|
||||
for month_col, (start_date, end_date) in self.month_ranges.items():
|
||||
sql = """
|
||||
SELECT AVG(close_price * float_share) as avg_close
|
||||
FROM {}
|
||||
WHERE stock_code = %s
|
||||
AND trade_date BETWEEN %s AND %s
|
||||
AND close_price IS NOT NULL
|
||||
AND float_share IS NOT NULL
|
||||
""".format(source_table)
|
||||
|
||||
result = db.execute_query(sql, (stock_code, start_date, end_date))
|
||||
# 保存小数点后两位,以亿为单位
|
||||
# monthly_data[month_col] = float(result[0]['avg_close']) * 1000 if result and result[0]['avg_close'] else None
|
||||
monthly_data[month_col] = round(float(result[0]['avg_close']) * 1000 / 100000000, 3) if result and result[0]['avg_close'] else None
|
||||
|
||||
# 提取所有以 'ym_' 开头的键的值
|
||||
ym_values = [value for key, value in monthly_data.items() if key.startswith('ym_')]
|
||||
valid_ym_values = [value for value in ym_values if value is not None]
|
||||
|
||||
# 计算全部月的均值
|
||||
if valid_ym_values:
|
||||
average = sum(valid_ym_values) / len(valid_ym_values)
|
||||
monthly_data['avg_all'] = average
|
||||
self.logger.debug(f"股票 {stock_code} 月间流通市值平均值为: {average}")
|
||||
else:
|
||||
monthly_data['avg_all'] = 0 # 给一个空值,保证数据库不报错
|
||||
self.logger.warning(f"股票 {stock_code} 没有有效的月度数据")
|
||||
|
||||
# 插入或更新数据
|
||||
upsert_sql = f"""
|
||||
INSERT INTO {target_table} (
|
||||
stock_code, stock_name,
|
||||
ym_2501, ym_2502, ym_2503, ym_2504,
|
||||
ym_2505, ym_2506,ym_2507, ym_2508,
|
||||
ym_2509,
|
||||
avg_all
|
||||
) VALUES (
|
||||
%(stock_code)s, %(stock_name)s,
|
||||
%(ym_2501)s, %(ym_2502)s, %(ym_2503)s, %(ym_2504)s,
|
||||
%(ym_2505)s, %(ym_2506)s, %(ym_2507)s, %(ym_2508)s,
|
||||
%(ym_2509)s,
|
||||
%(avg_all)s
|
||||
)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
stock_name = VALUES(stock_name),
|
||||
ym_2501 = VALUES(ym_2501),
|
||||
ym_2502 = VALUES(ym_2502),
|
||||
ym_2503 = VALUES(ym_2503),
|
||||
ym_2504 = VALUES(ym_2504),
|
||||
ym_2505 = VALUES(ym_2505),
|
||||
ym_2506 = VALUES(ym_2506),
|
||||
ym_2507 = VALUES(ym_2507),
|
||||
ym_2508 = VALUES(ym_2508),
|
||||
ym_2509 = VALUES(ym_2509),
|
||||
avg_all = VALUES(avg_all),
|
||||
update_time = CURRENT_TIMESTAMP
|
||||
"""
|
||||
db.execute_update(upsert_sql, monthly_data)
|
||||
# self.logger.info("月度均值计算和保存完成")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -435,4 +443,6 @@ class MarketDataCalculator:
|
||||
if csv_file:
|
||||
csv_success = self.export_to_csv(monthly_data, file_path)
|
||||
|
||||
return csv_success
|
||||
return csv_success
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ class KLineUpdater:
|
||||
# 预处理数据
|
||||
processed_data = self.preprocess_quote_data(quote_data, float_share)
|
||||
if not processed_data:
|
||||
self.logger.error("没有有效数据需要保存")
|
||||
self.logger.error(f"没有有效数据需要保存,表:{table_name}")
|
||||
return False
|
||||
|
||||
# 动态生成SQL插入语句
|
||||
@@ -252,7 +252,7 @@ class KLineUpdater:
|
||||
self.logger.info(f"创建了新表: {table_name}")
|
||||
|
||||
affected_rows = db.execute_many(insert_sql, processed_data)
|
||||
self.logger.info(f"成功插入/更新 {affected_rows} 条行情记录到表 {table_name}")
|
||||
# self.logger.info(f"成功插入/更新 {affected_rows} 条行情记录到表 {table_name}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"保存行情数据到表 {table_name} 失败: {str(e)}")
|
||||
|
||||
@@ -23,7 +23,9 @@ class ConfigInfo:
|
||||
'ym_2506': '2025年06月',
|
||||
'ym_2507': '2025年07月',
|
||||
'ym_2508': '2025年08月',
|
||||
'avg_all': '月度均值'
|
||||
'ym_2509': '2025年09月',
|
||||
'avg_all': '月度均值',
|
||||
'in_out':'是否在港股通'
|
||||
}
|
||||
|
||||
# 月份范围配置
|
||||
@@ -35,6 +37,7 @@ class ConfigInfo:
|
||||
'ym_2505': ('2025-05-01', '2025-05-31'),
|
||||
'ym_2506': ('2025-06-01', '2025-06-30'),
|
||||
'ym_2507': ('2025-07-01', '2025-07-31'),
|
||||
'ym_2508': ('2025-08-01', '2025-08-31')
|
||||
'ym_2508': ('2025-08-01', '2025-08-31'),
|
||||
'ym_2509': ('2025-09-01', '2025-09-30')
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import pymysql
|
||||
from pymysql import Error
|
||||
from typing import List, Dict, Union, Optional, Tuple
|
||||
from contextlib import contextmanager
|
||||
from base.LogHelper import LogHelper
|
||||
from .LogHelper import LogHelper
|
||||
|
||||
# 基本用法(自动创建日期日志+控制台输出)
|
||||
logger = LogHelper(logger_name = 'database').setup()
|
||||
@@ -84,7 +84,9 @@ class MySQLHelper:
|
||||
"""
|
||||
try:
|
||||
self.cursor.execute(sql, params)
|
||||
return self.cursor.fetchall()
|
||||
result = self.cursor.fetchall()
|
||||
|
||||
return result
|
||||
except Error as e:
|
||||
logger.error(f"查询执行失败: {e}")
|
||||
return []
|
||||
@@ -233,4 +235,4 @@ class MySQLHelper:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
self.close()
|
||||
|
||||
499
base/StockDataImporter.py
Normal file
499
base/StockDataImporter.py
Normal file
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
存储上海证券交易股票列表数据
|
||||
|
||||
不确定其数据爬取规则,防止 IP 被封
|
||||
暂时使用该方案,获取股票列表数据
|
||||
—— 下载excel,收到导入到数据库
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import os
|
||||
import sys
|
||||
import csv
|
||||
import chardet # 用于检测文件编码
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from .MySQLHelper import MySQLHelper
|
||||
from .LogHelper import LogHelper
|
||||
|
||||
logger = LogHelper(logger_name = 'execelImport').setup()
|
||||
|
||||
class StockDataImporter:
    """Import a stock list from a CSV file into the ``hk_stock_connect`` table.

    The import pipeline (see :meth:`run_import`) locates the CSV file in
    ``data_dir``, detects its encoding and delimiter, loads it with pandas,
    renames the columns through ``column_mapping`` and upserts the rows.
    Every imported row gets ``in_out = 1``; all pre-existing rows are reset
    to ``in_out = 0`` first, so rows missing from the new file end up at 0.

    Attributes set by ``__init__``:
        db_config: keyword arguments forwarded to ``MySQLHelper``.
        column_mapping: CSV header -> database column name.
        data_dir: directory searched for the CSV file.
        df: the loaded DataFrame (``None`` until a file has been read).
        csv_file: path of the file used by the last ``run_import`` call.
        encoding / delimiter: values detected for the current file.
        upload_filename: file name (or glob pattern) to look for.
    """

    def __init__(self, db_config: dict, column_mapping: dict, data_dir: Path):
        self.db_config = db_config
        self.column_mapping = column_mapping
        self.data_dir = data_dir
        self.df = None
        self.csv_file = None
        self.encoding = 'utf-8'  # default encoding
        self.delimiter = ','  # default delimiter
        self.upload_filename = None  # file name to import

    def setUploadfile(self, filename: str):
        """Set the name of the file that ``find_csv_file`` should look for."""
        self.upload_filename = filename
        logger.info(f"设置上传文件名为: {filename}")

    def find_csv_file(self) -> "Path | None":
        """Locate the CSV file inside ``data_dir``.

        Uses ``upload_filename`` when set, otherwise ``GPLIST.csv``. The
        name is passed to ``Path.glob``, so it may contain wildcards; when
        several files match, the most recently modified one is returned.

        Returns:
            The matching path, or ``None`` when nothing matches.
        """
        filename = self.upload_filename if self.upload_filename else "GPLIST.csv"

        csv_files = list(self.data_dir.glob(filename))

        if not csv_files:
            logger.error(f"在 {self.data_dir} 中没有找到文件: {filename}")
            return None

        # Several candidates: prefer the newest one.
        if len(csv_files) > 1:
            csv_files.sort(key=os.path.getmtime, reverse=True)
            logger.info(f"找到多个文件,选择最新的: {csv_files[0].name}")

        return csv_files[0]

    def validate_file(self, file_path: Path) -> bool:
        """Return True when *file_path* exists and is not empty."""
        try:
            if not file_path.exists():
                logger.error(f"CSV文件不存在: {file_path}")
                return False

            if file_path.stat().st_size == 0:
                logger.error(f"CSV文件为空: {file_path}")
                return False

            return True
        except Exception as e:
            logger.error(f"文件验证失败: {e}")
            return False

    def detect_file_encoding(self, file_path: Path) -> str:
        """Guess the text encoding of *file_path* with chardet.

        Only the first 10000 bytes are inspected. A few detected names are
        mapped to a wider or canonical codec (e.g. GB2312 -> GBK).

        Returns:
            The encoding name; ``'utf-8'`` when detection yields nothing or
            raises.
        """
        try:
            with open(file_path, 'rb') as f:
                raw_data = f.read(10000)

            result = chardet.detect(raw_data)
            encoding = result['encoding']
            confidence = result['confidence']

            # Replacements for commonly detected names.
            encoding_map = {
                'GB2312': 'GBK',
                'gb2312': 'GBK',
                'ISO-8859-1': 'latin1',
                'ascii': 'utf-8'
            }
            encoding = encoding_map.get(encoding, encoding)

            logger.info(f"检测到编码: {encoding} (置信度: {confidence:.2f})")
            return encoding or 'utf-8'
        except Exception as e:
            logger.error(f"编码检测失败: {e}, 使用默认UTF-8")
            return 'utf-8'

    def detect_csv_delimiter(self, file_path: Path) -> str:
        """Guess the delimiter by counting candidates in the first 5 lines.

        The file is opened with ``self.encoding``. The candidate
        (``,``, tab, ``;``, ``|``) with the highest total count wins.

        Returns:
            The detected delimiter; ``','`` when no candidate occurs or the
            file cannot be read.
        """
        try:
            with open(file_path, 'r', encoding=self.encoding) as f:
                # Take at most the first five lines. The previous version
                # called f.readline() twice per iteration (once in the
                # filter, once for the value) and so skipped every other
                # line. zip() stops on range(5) before pulling a sixth line.
                lines = [line for _, line in zip(range(5), f)]

            delimiters = [',', '\t', ';', '|']
            delimiter_counts = {
                delim: sum(line.count(delim) for line in lines)
                for delim in delimiters
            }

            best_delim = max(delimiter_counts, key=delimiter_counts.get)

            # No candidate found at all: fall back to a comma.
            if delimiter_counts[best_delim] == 0:
                logger.warning("无法检测到有效的分隔符,使用默认逗号分隔符")
                return ','

            logger.info(f"检测到分隔符: {repr(best_delim)}")
            return best_delim
        except Exception as e:
            logger.error(f"检测分隔符失败: {e}, 使用默认逗号分隔符")
            return ','

    def _read_frame(self, file_path: Path, encoding: str) -> pd.DataFrame:
        """Load *file_path* as an all-string DataFrame using ``self.delimiter``."""
        return pd.read_csv(
            file_path,
            delimiter=self.delimiter,
            dtype=str,
            encoding=encoding,
            on_bad_lines='warn',
            quoting=csv.QUOTE_MINIMAL,
            engine='python',  # the more tolerant parser
        )

    def _read_with_fallback_encodings(self, file_path: Path) -> bool:
        """Retry the read with a fixed list of encodings; True on first success."""
        # NOTE(review): latin1 can decode any byte sequence, so the entries
        # after it are effectively never reached, and a wrongly-encoded file
        # may load as mojibake. The list is kept as it was.
        for enc in ['GBK', 'latin1', 'ISO-8859-1', 'utf-16']:
            logger.warning(f"尝试使用 {enc} 编码读取文件")
            try:
                frame = self._read_frame(file_path, enc)
            except Exception as e:
                logger.warning(f"使用 {enc} 编码读取失败: {e}")
                continue

            self.df = frame
            self.encoding = enc
            logger.info(f"成功使用 {enc} 编码读取文件")
            return True

        logger.error("所有编码尝试均失败")
        return False

    def _finalize_dataframe(self) -> bool:
        """Reject an empty ``self.df``; otherwise rename columns and drop blank rows."""
        if self.df.empty:
            logger.error("CSV文件没有包含有效数据")
            return False

        self.df = self.df.rename(columns=self.column_mapping)
        self.df = self.df.dropna(how='all')

        logger.info(f"成功读取CSV数据,共 {len(self.df)} 条记录")
        return True

    def read_csv_data(self, file_path: Path) -> bool:
        """Load *file_path* into ``self.df``.

        Detects encoding and delimiter, reads the file, then renames the
        columns via ``column_mapping`` and drops fully-empty rows. If the
        detected encoding fails to decode the file, a fixed list of
        fallback encodings is tried. The fallback path now goes through the
        same rename/cleanup step as the primary path; previously it
        returned the raw frame, whose unmapped column names made every
        inserted value ``None``.

        Returns:
            True when a non-empty frame was loaded, False otherwise.
        """
        try:
            # 1. Detect the encoding, 2. detect the delimiter.
            self.encoding = self.detect_file_encoding(file_path)
            self.delimiter = self.detect_csv_delimiter(file_path)

            # 3. Read the file.
            logger.info(f"使用编码 '{self.encoding}' 和分隔符 '{self.delimiter}' 读取文件")
            try:
                self.df = self._read_frame(file_path, self.encoding)
            except UnicodeDecodeError:
                if not self._read_with_fallback_encodings(file_path):
                    return False

            return self._finalize_dataframe()
        except PermissionError:
            logger.error(f"文件被占用,请关闭后重试: {file_path}")
            return False
        except Exception as e:
            logger.error(f"读取CSV文件失败: {e}")
            return False

    def clean_stock_data(self) -> bool:
        """Run basic checks on the loaded data (stock-code format only).

        Rows are never removed here; codes that are not exactly six digits
        are only counted and logged.

        Returns:
            True unless an exception is raised.
        """
        try:
            if 'stock_code' in self.df.columns:
                # NOTE(review): other parts of this project use codes like
                # "HK.02938" (five digits); confirm that six digits is the
                # intended format for this list.
                codes = self.df['stock_code'].astype(str)
                invalid_codes = self.df[~codes.str.match(r'^\d{6}$')]
                if not invalid_codes.empty:
                    logger.warning(f"发现 {len(invalid_codes)} 条无效的股票代码")
                    logger.debug(f"无效代码示例: {invalid_codes['stock_code'].head().tolist()}")

            logger.info("数据清洗完成")
            return True
        except Exception as e:
            logger.error(f"数据清洗失败: {e}")
            return False

    def create_stocks_table(self, db: MySQLHelper) -> bool:
        """Create ``hk_stock_connect`` if it does not exist yet.

        The table has one column per value of ``column_mapping`` (known
        names get a specific type, unknown ones ``VARCHAR(255)``), plus
        ``in_out``, ``created_at`` and ``updated_at``.

        Returns:
            True when the statement ran without raising, False otherwise.
        """
        column_type_mapping = {
            'stock_code': 'VARCHAR(6) PRIMARY KEY',
            'stock_name_chn': 'VARCHAR(50) NULL',
            'stock_name_en': 'VARCHAR(150)',
        }

        column_definitions = [
            f"{column_name} {column_type_mapping.get(column_name, 'VARCHAR(255)')}"
            for column_name in self.column_mapping.values()
        ]

        # In/out flag followed by the two timestamp columns.
        column_definitions.append("in_out TINYINT(1) DEFAULT 0 COMMENT '进出标志'")
        column_definitions.append("created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间'")
        column_definitions.append("updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间'")

        columns_sql = ",\n            ".join(column_definitions)
        create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS hk_stock_connect (
            {columns_sql}
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='股票信息表';
        """

        try:
            db.execute_update(create_table_sql)
            logger.info("股票信息表创建成功")
            return True
        except Exception as e:
            logger.error(f"创建表失败: {e}")
            return False

    def insert_data_to_db(self, db: MySQLHelper) -> bool:
        """Upsert every row of ``self.df`` into ``hk_stock_connect``.

        Only the mapped columns are taken from the frame; ``in_out`` is
        hard-coded to 1 in the SQL, also on duplicate keys. Rows are sent
        in batches of 1000.

        Returns:
            True when all batches ran without raising, False otherwise.
        """
        if self.df is None or self.df.empty:
            logger.error("没有有效数据可插入")
            return False

        mapped_columns = list(self.column_mapping.values())

        # INSERT part: mapped columns as placeholders, in_out as literal 1.
        columns_sql = ", ".join(mapped_columns + ['in_out'])
        placeholders = ", ".join(["%s"] * len(mapped_columns) + ["1"])

        # UPDATE part: everything except the primary key, plus in_out.
        update_clauses = [
            f"{column} = VALUES({column})"
            for column in mapped_columns
            if column != 'stock_code'
        ]
        update_clauses.append("in_out = VALUES(in_out)")
        update_sql = ", ".join(update_clauses)

        insert_sql = f"""
        INSERT INTO hk_stock_connect (
            {columns_sql}
        ) VALUES (
            {placeholders}
        )
        ON DUPLICATE KEY UPDATE
            {update_sql}
        """

        # One tuple per row; NaN and missing columns become NULL.
        params_list = [
            tuple(
                row[column] if column in row and pd.notna(row[column]) else None
                for column in mapped_columns
            )
            for _, row in self.df.iterrows()
        ]

        try:
            total_rows = len(params_list)
            if total_rows == 0:
                logger.error("没有有效数据可插入")
                return False

            batch_size = 1000

            logger.info(f"开始插入数据,共 {total_rows} 条记录")

            # Insert in batches to avoid one oversized transaction.
            for i in range(0, total_rows, batch_size):
                db.execute_many(insert_sql, params_list[i:i + batch_size])
                logger.info(f"已处理 {min(i+batch_size, total_rows)}/{total_rows} 条记录")

            logger.info(f"成功插入/更新 {total_rows} 条记录")
            return True
        except Exception as e:
            logger.error(f"插入数据失败: {e}")
            # Log the first few parameter tuples to help debugging.
            if params_list:
                logger.debug(f"前5个参数示例: {params_list[:5]}")
            return False

    def setInOut(self, db: MySQLHelper) -> bool:
        """Reset ``in_out`` to 0 on every row without touching ``updated_at``.

        Called before the insert step so that only rows present in the new
        file end up with ``in_out = 1``. The statement previously set the
        flag to 1, contradicting both this contract and the log message,
        which left rows dropped from the list flagged forever.

        Returns:
            True when the statement ran without raising, False otherwise.
        """
        try:
            # "updated_at = updated_at" keeps ON UPDATE CURRENT_TIMESTAMP
            # from bumping the timestamp.
            update_sql = "UPDATE hk_stock_connect SET in_out = 0, updated_at = updated_at"
            affected_rows = db.execute_update(update_sql)
            logger.info(f"成功设置 {affected_rows} 条记录的进出标志为0")
            return True
        except Exception as e:
            logger.error(f"设置进出标志失败: {e}")
            return False

    def verify_data_in_db(self, db: MySQLHelper, sample_size: int = 5) -> bool:
        """Log the row count and a random sample of ``hk_stock_connect``.

        Args:
            db: open database helper.
            sample_size: number of random rows to log.

        Returns:
            True unless an exception is raised.
        """
        try:
            result = db.execute_query("SELECT COUNT(*) AS total FROM hk_stock_connect")
            db_count = result[0]['total'] if result else 0
            logger.info(f"数据库中共有 {db_count} 条记录")

            # Show every mapped column plus in_out; "*" when nothing is mapped.
            mapped_columns = list(self.column_mapping.values())
            select_columns = mapped_columns + ['in_out'] if mapped_columns else ["*"]
            columns_sql = ", ".join(select_columns)

            sample_sql = f"""
            SELECT {columns_sql}
            FROM hk_stock_connect
            ORDER BY RAND()
            LIMIT {sample_size}
            """
            samples = db.execute_query(sample_sql)

            logger.info("\n随机抽样记录:")
            for idx, sample in enumerate(samples, 1):
                sample_info = [
                    f"{column}: {sample[column]}"
                    for column in select_columns
                    if column in sample and sample[column] is not None
                ]
                logger.info(f"{idx}. {' | '.join(sample_info) if sample_info else 'No data to display'}")

            return True
        except Exception as e:
            logger.error(f"数据验证失败: {e}")
            return False

    def run_import(self) -> bool:
        """Run the full import: find, validate, read, clean, store, verify.

        Returns:
            True when every step succeeded, False as soon as one fails.
        """
        logger.info(f"开始导入股票数据,数据目录: {self.data_dir}")
        start_time = datetime.now()

        # 1. Locate the CSV file.
        csv_file = self.find_csv_file()
        if not csv_file:
            return False
        self.csv_file = csv_file

        # 2-4. Validate the file, read it, clean the data.
        if not self.validate_file(csv_file):
            return False
        if not self.read_csv_data(csv_file):
            return False
        if not self.clean_stock_data():
            return False

        # Log the first five rows, showing whichever known columns exist.
        logger.info("\n前5条股票数据:")
        for i, row in self.df.head().iterrows():
            display_info = []
            if 'stock_code' in row:
                display_info.append(f"代码: {row['stock_code']}")
            if 'stock_name_chn' in row:
                display_info.append(f"名称: {row['stock_name_chn']}")
            if 'stock_name_en' in row:
                display_info.append(f"英文: {row['stock_name_en']}")

            logger.info(f"{i+1}. {' | '.join(display_info)}")

        # 5. Connect to the database and import.
        try:
            with MySQLHelper(**self.db_config) as db:
                # 5.1 Make sure the table exists.
                if not self.create_stocks_table(db):
                    return False

                # Reset all flags first; a failure is logged by setInOut
                # and does not abort the import.
                self.setInOut(db)

                # 5.2 Upsert the rows (sets in_out = 1).
                if not self.insert_data_to_db(db):
                    return False

                # 5.3 Log a sample for manual verification.
                if not self.verify_data_in_db(db):
                    return False
        except Exception as e:
            logger.error(f"数据库操作异常: {e}")
            return False

        duration = datetime.now() - start_time
        logger.info(f"数据处理成功完成! 总耗时: {duration.total_seconds():.2f}秒")
        return True
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: import the Stock Connect list from <script dir>/data.

    # Database connection settings.
    # NOTE(review): credentials are hard-coded; consider loading them from
    # a config file or the environment.
    db_config = {
        'host': 'localhost',
        'user': 'root',
        'password': 'bzskmysql',
        'database': 'hk_kline_1d'
    }

    # CSV header -> database column name.
    COLUMN_MAPPING = {
        '证券代码': 'stock_code',
        '中文简称': 'stock_name_chn',
        '英文简称': 'stock_name_en',
    }

    # Directory of this script (current working directory as a fallback).
    current_dir = Path(__file__).parent if "__file__" in globals() else Path.cwd()

    # Data directory; created when missing.
    DATA_DIR = current_dir / "data"
    DATA_DIR.mkdir(exist_ok=True, parents=True)

    # The former "pip install chardet on ImportError" bootstrap was removed:
    # chardet is imported unconditionally at the top of this module, so that
    # guarded import could never fail here.

    importer = StockDataImporter(db_config, COLUMN_MAPPING, DATA_DIR)

    # File to import (omit this call to fall back to the default file name).
    importer.setUploadfile("港股通标的证券名单.csv")

    if importer.run_import():
        logger.info("股票数据导入成功!")
    else:
        logger.error("股票数据导入失败,请检查日志了解详情")
|
||||
@@ -1,3 +1,4 @@
|
||||
from .MySQLHelper import MySQLHelper
|
||||
from .LogHelper import LogHelper
|
||||
from .Config import ConfigInfo
|
||||
from .Config import ConfigInfo
|
||||
from .StockDataImporter import StockDataImporter
|
||||
@@ -895,4 +895,8 @@ HK.08218
|
||||
HK.06960
|
||||
HK.02936
|
||||
HK.03858
|
||||
HK.08132
|
||||
HK.08132
|
||||
HK.02938
|
||||
HK.02580
|
||||
HK.02941
|
||||
HK.02935
|
||||
2
config/Removecode.txt
Normal file
2
config/Removecode.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
HK.04335
|
||||
HK.02292
|
||||
@@ -686,7 +686,6 @@ HK.02179
|
||||
HK.00442
|
||||
HK.01959
|
||||
HK.01985
|
||||
HK.02992
|
||||
HK.00314
|
||||
HK.01459
|
||||
HK.01082
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
HK.02938
|
||||
@@ -1,5 +1,5 @@
|
||||
kevin_futu:
|
||||
已使用 1000 行情
|
||||
已使用 1000 行情 -> 移除 HK.02992
|
||||
hang_futu:
|
||||
已使用 703 行情
|
||||
HK牛仔:
|
||||
|
||||
86
main_gui.py
86
main_gui.py
@@ -102,6 +102,9 @@ class MainWindow(QMainWindow):
|
||||
# 创建功能按钮组
|
||||
self.create_button_group(main_layout)
|
||||
|
||||
# 创建数据导入按钮组
|
||||
self.create_import_button_group(main_layout)
|
||||
|
||||
# 创建进度条
|
||||
self.progress_bar = QProgressBar()
|
||||
# self.progress_bar.setVisible(False)
|
||||
@@ -307,6 +310,20 @@ class MainWindow(QMainWindow):
|
||||
|
||||
button_group.setLayout(button_layout)
|
||||
layout.addWidget(button_group)
|
||||
|
||||
def create_import_button_group(self, layout):
    """Build the "数据导入" group box with its import button and add it to *layout*."""
    # The button is stored on self so that set_buttons_enabled can toggle it.
    import_button = QPushButton('导入股票数据')
    import_button.setToolTip('从CSV文件导入股票数据到数据库')
    import_button.clicked.connect(self.on_import_clicked)
    self.btn_import = import_button

    # Horizontal row holding the single button.
    button_row = QHBoxLayout()
    button_row.addWidget(import_button)

    group_box = QGroupBox("数据导入")
    group_box.setLayout(button_row)
    layout.addWidget(group_box)
|
||||
|
||||
def create_log_area(self, layout):
|
||||
"""创建日志显示区域"""
|
||||
@@ -341,6 +358,8 @@ class MainWindow(QMainWindow):
|
||||
self.btn_export.setEnabled(enabled)
|
||||
self.btn_calculate.setEnabled(enabled)
|
||||
self.btn_check.setEnabled(enabled)
|
||||
self.btn_import.setEnabled(enabled)
|
||||
self.btn_float_share.setEnabled(enabled)
|
||||
|
||||
def on_update_clicked(self):
|
||||
"""更新数据按钮点击事件"""
|
||||
@@ -401,6 +420,17 @@ class MainWindow(QMainWindow):
|
||||
worker.finished_signal.connect(self.on_task_finished)
|
||||
worker.start()
|
||||
self.worker_threads.append(worker)
|
||||
|
||||
def on_import_clicked(self):
    """Handle a click on the import button by running the import in a worker thread."""
    self.log_message("开始导入股票数据...")
    # Buttons are disabled here; on_task_finished is connected to the worker's finished signal.
    self.set_buttons_enabled(False)

    import_worker = WorkerThread(self.import_stock_data)
    wiring = (
        (import_worker.log_signal, self.log_message),
        (import_worker.finished_signal, self.on_task_finished),
    )
    for signal, slot in wiring:
        signal.connect(slot)

    import_worker.start()
    # Keep a reference to the worker in the window's thread list.
    self.worker_threads.append(import_worker)
|
||||
|
||||
def on_task_finished(self, success, message):
|
||||
"""任务完成回调"""
|
||||
@@ -486,8 +516,9 @@ class MainWindow(QMainWindow):
|
||||
|
||||
# 移除人民币交易的股票:股票名称最后一个字符为R,误删除的从配置文件读回来
|
||||
reserved_codes = calculator.read_stock_codes_list(Path.cwd() / "config" / "Reservedcode.txt")
|
||||
remove_codes = calculator.read_stock_codes_list(Path.cwd() / "config" / "Removecode.txt")
|
||||
market_data_ll = calculator.get_stock_codes() # 使用按照价格和流通股数量筛选的那个表格
|
||||
market_data = market_data_ll + reserved_codes
|
||||
market_data = market_data_ll + reserved_codes - remove_codes
|
||||
|
||||
# 根据统计时间进行命名
|
||||
target_table_name = 'hk_monthly_avg_' + datetime.now().strftime("%Y%m%d")
|
||||
@@ -496,11 +527,10 @@ class MainWindow(QMainWindow):
|
||||
|
||||
# 使用tqdm创建进度条
|
||||
for code in tqdm(market_data, desc="处理股票数据", unit="支"):
|
||||
tablename = 'hk_' + code[3:]
|
||||
# 计算并保存月度均值
|
||||
calculator.calculate_and_save_monthly_avg(
|
||||
source_table=tablename,
|
||||
target_table=target_table_name
|
||||
stock_code =code,
|
||||
target_table = target_table_name
|
||||
)
|
||||
|
||||
# self.log_message("月度平均计算完成")
|
||||
@@ -529,6 +559,54 @@ class MainWindow(QMainWindow):
|
||||
except Exception as e:
|
||||
self.log_message(f"数据检查失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def import_stock_data(self):
    """Worker task: import the Stock Connect list CSV into the database.

    Builds a ``StockDataImporter`` for ``<script dir>/data`` and imports
    ``港股通标的证券名单.csv``.

    Returns:
        bool: True when the import succeeded, False when it failed.

    Raises:
        Exception: any unexpected error is logged to the GUI and re-raised.
    """
    try:
        # Imported lazily so the modules are only loaded when needed.
        from base.StockDataImporter import StockDataImporter
        from pathlib import Path

        # Database connection settings.
        # NOTE(review): credentials are hard-coded; consider loading them
        # from a config file or the environment.
        db_config = {
            'host': 'localhost',
            'user': 'root',
            'password': 'bzskmysql',
            'database': 'hk_kline_1d'
        }

        # CSV header -> database column name.
        column_mapping = {
            '证券代码': 'stock_code',
            '中文简称': 'stock_name_chn',
            '英文简称': 'stock_name_en',
        }

        # __file__ is a module global, never a function local, so the former
        # '"__file__" in locals()' test was always False inside this method
        # and the script-directory branch was unreachable.
        current_dir = Path(__file__).parent if "__file__" in globals() else Path.cwd()
        data_dir = current_dir / "data"

        # Make sure the data directory exists.
        data_dir.mkdir(exist_ok=True, parents=True)

        importer = StockDataImporter(db_config, column_mapping, data_dir)
        importer.setUploadfile("港股通标的证券名单.csv")

        if importer.run_import():
            self.log_message("股票数据导入成功!")
            return True

        self.log_message("股票数据导入失败,请检查日志了解详情")
        return False

    except Exception as e:
        self.log_message(f"导入股票数据失败: {str(e)}")
        raise
|
||||
|
||||
def closeEvent(self, event):
|
||||
"""窗口关闭事件"""
|
||||
|
||||
Reference in New Issue
Block a user