Wednesday, November 28, 2012

Stocktotal, Part 5: Scraping Per-Stock Dividend Policies

Scraping the data takes three main steps:
  1. Source web contents
  2. Parse web contents
  3. Insert database


1. Source Web Contents

Link: http://estockweb.standardchartered.com.tw/z/zc/zcc/zcc_1101.djhtm

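Change 1101 in the URL to 2498 and the page shows the dividend policy for stock 2498. In other words, we can retrieve the dividend policy of any individual stock from this site simply by substituting its stock code.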
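The substitution described above can be sketched as a small helper. The URL pattern comes from the link above; the names URL_TEMPLATE and dividend_url are made up for this illustration.

```python
# Pattern taken from the link above; only the stock code varies.
URL_TEMPLATE = 'http://estockweb.standardchartered.com.tw/z/zc/zcc/zcc_{stock_code}.djhtm'


def dividend_url(stock_code):
    """Return the dividend-policy page URL for one stock code (hypothetical helper)."""
    return URL_TEMPLATE.format(stock_code=stock_code)


for code in ('1101', '2498'):
    print(dividend_url(code))
```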

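Downloading is done with a third-party tool, wget. wget is available for both Windows and Mac, so portability is not a concern. How, then, do we write cross-platform code in Python? There are plenty of resources online on this; I borrowed someone else's approach and improved on it.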

First, write the interface, wget.py:
import platform

# Pick the platform-specific implementation at import time.
system_name = platform.system()
if system_name == 'Windows':
    from . import wget_win as wget_private
elif system_name == 'Darwin':
    from . import wget_mac as wget_private
else:
    raise Exception('Please add support for your platform')

def wget(cmdline):
    wget_private.wget(cmdline)
It is just a single wget function that takes a command line. For example:
url --waitretry=3 -O dest_file
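That is: which URL to fetch, waitretry set to 3, and which destination file to save to. The waitretry setting matters: wget waits between retries of a failed download, backing off up to the given number of seconds. Networks are unreliable, so we can neither wait indefinitely nor skip waiting altogether; we take the middle path.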

In the code, we use the platform module to find out which platform we are running on. On Windows, for example, I write wget_win.py; it has a bit of a d-pointer flavor.
import os

def wget(cmdline):
    wget = os.path.abspath('./core/thirdparty/wget/wget.exe')
    assert os.path.isfile(wget)  
    wget_cmdline = '''{wget} {cmdline}'''.format(wget=wget, cmdline=cmdline)
    os.system(wget_cmdline)

wget_mac.py
import os

def wget(cmdline):
    wget_cmdline = '''wget {cmdline}'''.format(cmdline=cmdline)
    os.system(wget_cmdline)

The two differ only slightly.
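As an alternative to building one shell string, the same command can be passed to subprocess as an argument list, which avoids quoting problems in URLs and paths. This is only a sketch, not the code used in Stocktotal: WGET_BY_PLATFORM, wget_args and download are hypothetical names, and the Windows path simply mirrors the one in wget_win.py.

```python
import platform
import subprocess

# Assumed wget locations, mirroring wget_win.py and wget_mac.py above.
WGET_BY_PLATFORM = {
    'Windows': './core/thirdparty/wget/wget.exe',
    'Darwin': 'wget',
}


def wget_args(url, dest_file, system=None):
    """Build the wget argument list for one download."""
    system = system or platform.system()
    try:
        wget = WGET_BY_PLATFORM[system]
    except KeyError:
        raise NotImplementedError('Please add support for ' + system)
    return [wget, url, '--waitretry=3', '-O', dest_file]


def download(url, dest_file):
    """Run wget and return its exit code (0 means success)."""
    return subprocess.call(wget_args(url, dest_file))
```

Because each argument is a separate list element, the URL and the destination file no longer need to be wrapped in escaped quotes.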



2. Parse Web Contents

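This is the most complicated part: the web contents of different pages may be quite similar or very different. Since web contents are mostly HTML, the lxml module is quite powerful for extracting data, though there are even more powerful options.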

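Chinese web pages, however, are a real mess, especially when it comes to encoding, so we have to try several decoders. For example, the snippet that reads the HTML content looks like this: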
src_fd = open(src_file, 'rb')
src_content = src_fd.read()
src_fd.close()

content = None
try:
    content = html.fromstring(src_content.decode('big5-hkscs'))
except UnicodeDecodeError as e:
    self.LOGGER.debug(e)
    content = html.fromstring(src_content.decode('gb18030'))
If neither decoder works, we are in trouble, but for now these two are enough for the pages we scrape.
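The try/except chain above can be generalized to an ordered list of encodings. The following is a minimal sketch; decode_with_fallback is a hypothetical helper and the default encodings are simply the two used above.

```python
def decode_with_fallback(raw, encodings=('big5-hkscs', 'gb18030')):
    """Try each encoding in order and return (text, encoding_used).

    Re-raises the last UnicodeDecodeError if every encoding fails.
    """
    last_error = None
    for encoding in encodings:
        try:
            return raw.decode(encoding), encoding
        except UnicodeDecodeError as e:
            last_error = e
    if last_error is None:
        raise ValueError('no encodings given')
    raise last_error


# Sample bytes made up for this example.
text, used = decode_with_fallback('股利政策'.encode('big5'))
print(used, text)
```

Adding a third decoder then means appending one name to the tuple rather than nesting another try block.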

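If you are familiar with XPath, extracting the data from here is easy. These pages are usually generated by a program, so their format tends to be regular. An XPath snippet: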
for table in content.xpath('//html/body/div/table/tr/td[@width="99%"]/table/tr/td/table/tr/td/table'):
    for yearly_dataset in table.xpath('./tr'):
        yearly_data = yearly_dataset.xpath('./td/text()')
        if len(yearly_data) == 7:
            activity_date = self.get_date(yearly_data[0])
            if not activity_date:
                continue
            record = [
                self.STOCK_CODE, 
                activity_date, 
                self.get_double(yearly_data[1]),
                self.get_double(yearly_data[2]),
                self.get_double(yearly_data[3]),
                self.get_double(yearly_data[4]),
                self.get_double(yearly_data[5]),
                self.get_double(yearly_data[6]),
            ]
In short, once you know these basic techniques the rest is up to you; if you get stuck, just Google it.
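To make the row filtering concrete, here is a sketch that works on rows already extracted as lists of cell strings, so it runs without lxml. It follows the same rules as the snippet above: keep only 7-cell rows, skip rows whose first cell is not a year, and convert the ROC year by adding 1911. The function names and the sample rows are made up for illustration.

```python
from datetime import date


def roc_year_to_date(literal):
    """Convert an ROC year such as '100' to date(2011, 1, 1); None if not a year."""
    try:
        return date(int(literal) + 1911, 1, 1)
    except ValueError:
        return None


def to_float(literal):
    """Parse a number that may contain thousands separators; None if not numeric."""
    try:
        return float(literal.replace(',', ''))
    except ValueError:
        return None


def rows_to_records(stock_code, rows):
    """Turn raw table rows into [stock_code, date, six floats] records."""
    records = []
    for cells in rows:
        if len(cells) != 7:
            continue  # not a yearly data row
        activity_date = roc_year_to_date(cells[0])
        if activity_date is None:
            continue  # header row: first cell is not a year
        records.append([stock_code, activity_date] + [to_float(c) for c in cells[1:]])
    return records


# Made-up sample: a header row, one data row, and an unrelated short row.
rows = [
    ['Year', 'Cash', 'Earnings', 'Reserve', 'Stock', 'Total', 'Bonus %'],
    ['100', '1.90', '0.00', '0.00', '0.00', '1.90', 'N/A'],
    ['note', 'x', 'y'],
]
print(rows_to_records('1101', rows))
```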



3. Insert Database

Taking PostgreSQL as the example, we first need to prepare the schema:
create table if not exists StockDividend
(
    creation_dt timestamp default current_timestamp,
    stock_code text not null,
    activity_date date not null, 
    cash_dividend double precision,
    stock_dividend_from_retained_earnings double precision,
    stock_dividend_from_capital_reserve double precision,
    stock_dividend double precision,
    total_dividend double precision,
    profit_sharing_percentage double precision,
    unique (stock_code, activity_date)
);

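Run the SQL statement in pgAdmin III, paying attention to which role you run it as, and remember to grant privileges to the appropriate users. After that, PostgreSQL operations are carried out with the py-postgresql module.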


There is one more small point worth making here.

Program to an interface, not an implementation


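Working with PostgreSQL directly may be convenient, but if we later want to switch to MySQL or SQL Server we need to consider how easy the code is to change. I already went through this once when I moved from SQLite to PostgreSQL.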

db_config.py
DB_TYPE = 'postgres'
If we want to switch databases in the future, we just change this setting and add the appropriate implementation.


insertion/insertion_factory.py
from .. import db_config
db_type = db_config.DB_TYPE

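if db_type == 'sqlite':
    from .sqlite import insertion_factory as factory_private
elif db_type == 'postgres':
    from .postgres import insertion_factory as factory_private
else:
    # Fail fast instead of hitting a NameError later in InsertionFactory.
    raise ValueError('Unsupported DB_TYPE: ' + str(db_type))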
    
class InsertionFactory():
    @staticmethod
    def insertion():
        return factory_private.InsertionFactory().insertion()
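For inserts we use the factory pattern to create the appropriate Insertion object based on DB_TYPE. I am a bit fussy here: the outermost factory only delegates to the database-specific factory (the PostgreSQL one in this case), and how the object is finally created is left for that specific factory to worry about.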
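The same selection by DB_TYPE can also be written as a lookup table. The sketch below is a simplified variant, not the structure used above: the two Insertion classes are empty stand-ins and create_insertion is a hypothetical name.

```python
class SqliteInsertion:
    """Stand-in for a SQLite-backed Insertion (hypothetical)."""


class PostgresInsertion:
    """Stand-in for a PostgreSQL-backed Insertion (hypothetical)."""


# Maps each supported DB_TYPE value to the class that implements it.
_INSERTION_BY_DB_TYPE = {
    'sqlite': SqliteInsertion,
    'postgres': PostgresInsertion,
}


def create_insertion(db_type):
    """Return a new Insertion object for the configured database type."""
    try:
        insertion_class = _INSERTION_BY_DB_TYPE[db_type]
    except KeyError:
        raise ValueError('Unsupported DB_TYPE: {!r}'.format(db_type))
    return insertion_class()


print(type(create_insertion('postgres')).__name__)
```

One trade-off: a table like this refers to every implementation up front, whereas the conditional imports above load only the selected backend, so drivers for unused databases need not be installed.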

insertion/postgres/insertion_factory.py
from . import insertion
    
class InsertionFactory():
    @staticmethod
    def insertion():
        return insertion.Insertion()

insertion/postgres/insertion.py
import postgresql

class Insertion():

    def __init__(self):
        self.CONN_STRING = 'pq://stocktotal:stocktotal@localhost:5432/stocktotal'
        self.DB_CONN = None
        
    def open(self):
        self.DB_CONN = postgresql.open(self.CONN_STRING)
        
    def close(self):
        self.DB_CONN.close()
        self.DB_CONN = None

    def insert(self, sql_cmd, record):
        try:
            fixed_record = [None if _.strip() == '' else _ for _ in record]
            prepared_stmt = self.DB_CONN.prepare(sql_cmd)
            prepared_stmt(*fixed_record)
        except postgresql.exceptions.UniqueError:
            pass

    def insert_stock_dividend(self, record):
        sql_cmd = \
            '''
            insert into StockDividend(
                stock_code,
                activity_date,
                cash_dividend,
                stock_dividend_from_retained_earnings,
                stock_dividend_from_capital_reserve,
                stock_dividend,
                total_dividend,
                profit_sharing_percentage
            ) values(
                $1, 
                $2::text::date, 
                $3::text::float8, 
                $4::text::float8,
                $5::text::float8,
                $6::text::float8,
                $7::text::float8,
                $8::text::float8
            )
            '''
        self.insert(sql_cmd, record)
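The client only needs to pass in a list. Admittedly this interface is a bit rigid: the first element is the stock code, the second is the date, the third is something else, and so on. But it works well enough, and changing it later will not be too painful.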
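The two behaviors of insert (blank fields become NULL, and rows that violate the unique constraint are silently skipped) can be tried out with the standard-library sqlite3 module and an in-memory database. This is only an illustration under simplifying assumptions: the table is cut down to three columns, and insert_dividend is a hypothetical function, not part of the py-postgresql code above.

```python
import sqlite3

# Cut-down version of the StockDividend schema, for illustration only.
SCHEMA = '''
create table if not exists StockDividend
(
    stock_code text not null,
    activity_date text not null,
    cash_dividend real,
    unique (stock_code, activity_date)
)
'''


def insert_dividend(conn, record):
    """Insert one record; return False if the row already exists."""
    # Blank strings are stored as NULL, as in Insertion.insert above.
    fixed = [None if isinstance(v, str) and v.strip() == '' else v for v in record]
    try:
        with conn:  # commits on success, rolls back on error
            conn.execute(
                'insert into StockDividend(stock_code, activity_date, cash_dividend) '
                'values (?, ?, ?)',
                fixed,
            )
        return True
    except sqlite3.IntegrityError:
        return False  # duplicate (stock_code, activity_date)


conn = sqlite3.connect(':memory:')
conn.execute(SCHEMA)
print(insert_dividend(conn, ['1101', '2011-01-01', '1.9']))
print(insert_dividend(conn, ['1101', '2011-01-01', '1.9']))
print(insert_dividend(conn, ['2498', '2011-01-01', '']))
```

Skipping duplicates this way means the whole pipeline can be re-run on the same CSV files without creating repeated rows.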



Appendix:

standardchartered_source.py
import csv
import logging
import os
from datetime import date

class StandardcharteredSource():

    def __init__(self):
        from ..db.insertion import insertion_factory
        
        self.LOGGER = logging.getLogger()        
        self.URL_TEMPLATE = ''
        self.STOCK_CODE = None
        self.SOURCE_TYPE = None
        self.HTML_DIR = ''
        self.CSV_DIR = ''
        self.DB_INSERTION = insertion_factory.InsertionFactory().insertion()
        
    def source(self, stock_code):
        self.STOCK_CODE = stock_code
        self.source_url_to_html(self.HTML_DIR)        
        self.source_html_to_csv(self.HTML_DIR, self.CSV_DIR)
        self.source_csv_to_db(self.SOURCE_TYPE, self.CSV_DIR, self.DB_INSERTION)

    def source_url_to_html(self, dest_dir):
        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)
        url = self.__get_url()
        dest_file = self.get_filename(dest_dir, 'html')
        self.__wget(url, dest_file)
        
    def source_html_to_csv(self, src_dir, dest_dir):
        pass

    def source_csv_to_db(self, source_type, src_dir, db_insertion):            
        src_file = self.get_filename(src_dir, 'csv')
        if not os.path.isfile(src_file):
            return    
        self.LOGGER.debug('''{src_file} => db'''.format(src_file=src_file))
        fd = open(src_file, 'r')
        csv_reader = csv.reader(fd)
                
        INSERT_SOURCE_TYPE_MAP = {
            'capital_structure_summary': db_insertion.insert_capital_structure_summary,
            'stock_dividend': db_insertion.insert_stock_dividend,
        }                 
                
        db_insertion.open()
        for r in csv_reader:
            INSERT_SOURCE_TYPE_MAP[source_type](r)
            self.LOGGER.debug(r)
        db_insertion.close()
        
        fd.close()

    def get_filename(self, src_dir, ext):
        return os.path.join(src_dir, self.STOCK_CODE + '.' + ext) 

    # Get date from ROC year to date (Python data type)
    def get_date(self, literal):
        try:
            return date(int(literal) + 1911, 1, 1)
        except ValueError:
            return None        

    def get_double(self, literal):
        literal = literal.replace(',','')    
        try:
            return float(literal)
        except ValueError:
            return None
            
    def __get_url(self):
        return self.URL_TEMPLATE % self.STOCK_CODE

    def __wget(self, url, dest_file):
        from ..base import wget
        cmdline = '''\"{url}\" --waitretry=3 -O \"{dest_file}\"'''.format(url=url, dest_file=dest_file)
        wget.wget(cmdline)

stock_dividend_source.py
import csv
import os
from lxml import html

from . import standardchartered_source

class StockDividendSource(standardchartered_source.StandardcharteredSource):

    def __init__(self):
        standardchartered_source.StandardcharteredSource.__init__(self)

        self.URL_TEMPLATE = '''http://estockweb.standardchartered.com.tw/z/zc/zcc/zcc_%s.djhtm'''
        self.SOURCE_TYPE = 'stock_dividend'
        self.HTML_DIR = '../dataset/stock_dividend/html/'
        self.CSV_DIR = '../dataset/stock_dividend/csv/'

    def source_html_to_csv(self, src_dir, dest_dir):
        assert os.path.isdir(src_dir)
        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)

        src_file = self.get_filename(src_dir, 'html')
        dest_file = self.get_filename(dest_dir, 'csv')
        self.LOGGER.debug('''{src_file} => {dest_file}'''.format(src_file=src_file, dest_file=dest_file))
        assert os.path.isfile(src_file)
        
        dest_fd = open(dest_file, 'w', newline='')
        csv_writer = csv.writer(dest_fd)
        
        src_fd = open(src_file, 'rb')
        src_content = src_fd.read()
        src_fd.close()

        content = None
        try:
            content = html.fromstring(src_content.decode('big5-hkscs').replace('&nbsp;', ' ').replace('<BR>', ''))
        except UnicodeDecodeError as e:
            self.LOGGER.debug(e)
            content = html.fromstring(src_content.decode('gb18030').replace('&nbsp;', ' ').replace('<BR>', ''))

        for table in content.xpath('//html/body/div/table/tr/td[@width="99%"]/table/tr/td/table/tr/td/table'):
            for yearly_dataset in table.xpath('./tr'):
                yearly_data = yearly_dataset.xpath('./td/text()')
                if len(yearly_data) == 7:
                    activity_date = self.get_date(yearly_data[0])
                    if not activity_date:
                        continue
                    record = [
                        self.STOCK_CODE, 
                        activity_date, 
                        self.get_double(yearly_data[1]),
                        self.get_double(yearly_data[2]),
                        self.get_double(yearly_data[3]),
                        self.get_double(yearly_data[4]),
                        self.get_double(yearly_data[5]),
                        self.get_double(yearly_data[6]),
                    ]
                    csv_writer.writerow(record)
        dest_fd.close()
