使用 Python 连接数据库时，频繁手写连接器令人心烦，因此将其封装为一个类，需要时直接调用即可。
import re
import logging
from time import time
from abc import ABC, abstractmethod
from impala.dbapi import connect
from decorator import decorator
__all__ = ["Impala", "Hive"]
class LoggerFactory:
    """Create ``logging.Logger`` objects that write to the console, a file, or both.

    A logger is configured only while it has no handlers, so asking for the
    same name again returns the existing logger without duplicating output.
    Every logger configured here is set to ``DEBUG``.
    """

    @staticmethod
    def _setup(logger, fmt, make_handlers):
        """Attach the handlers returned by ``make_handlers()`` to an unconfigured logger.

        ``make_handlers`` is only invoked when the logger has no handlers yet,
        so a ``FileHandler`` never opens its file for an already configured
        logger. All new handlers share one formatter built from ``fmt``.
        """
        if logger.handlers:
            return logger
        logger.setLevel(logging.DEBUG)
        shared_formatter = logging.Formatter(fmt)
        for handler in make_handlers():
            handler.setFormatter(shared_formatter)
            logger.addHandler(handler)
        return logger

    @staticmethod
    def stream(name):
        """Return a console logger.

        ``name`` is either a string or an object exposing ``__name__`` (such as
        a function). The resolved name is upper-cased and is also included in
        each record's format.
        """
        named = hasattr(name, '__name__')
        assert named or isinstance(name, str), 'Input must be string or has attribute `__name__`.'
        label = (name.__name__ if named else name).upper()
        return LoggerFactory._setup(
            logging.getLogger(label),
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            lambda: [logging.StreamHandler()],
        )

    @staticmethod
    def file(name):
        """Return a logger writing UTF-8 text to ``f'{name}.log'``.

        A relative ``name`` resolves against the current working directory.
        """
        return LoggerFactory._setup(
            logging.getLogger(name),
            '%(asctime)s - %(levelname)s - %(message)s',
            lambda: [logging.FileHandler(f'{name}.log', encoding='utf-8')],
        )

    @staticmethod
    def both(name):
        """Return a logger writing to ``f'{name}.log'`` (UTF-8) and to the console."""
        return LoggerFactory._setup(
            logging.getLogger(name),
            '%(asctime)s - %(levelname)s - %(message)s',
            lambda: [
                logging.FileHandler(f'{name}.log', encoding='utf-8'),
                logging.StreamHandler(),
            ],
        )
@decorator
def timer(func, logger=None, *args, **kwargs):
    """Log when ``func`` starts and finishes, and how long it ran.

    Applied as plain ``@timer`` in this file. ``logger`` optionally supplies
    the logger to use; when it is falsy, a stream logger named after ``func``
    is created via ``LoggerFactory.stream``. If ``func`` raises, the exception
    propagates and no "Done"/timing lines are logged.

    Returns whatever ``func`` returns.
    """
    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump when the system wall clock is adjusted (unlike time.time()).
    from time import perf_counter

    if not logger:
        logger = LoggerFactory.stream(func)
    start = perf_counter()
    logger.info("Start")
    result = func(*args, **kwargs)
    elapsed = perf_counter() - start
    minutes, seconds = divmod(elapsed, 60)
    hours, minutes = divmod(minutes, 60)
    logger.info("Done")
    # %d truncates the float parts, giving an H:MM:SS display.
    logger.info("Total execution time: %d:%02d:%02d", hours, minutes, seconds)
    return result
class BaseConnector(ABC):
    """Connect to a Hadoop SQL engine and run queries through a single cursor.

    Subclasses choose the engine by passing its name to ``__init__`` and
    implement ``_run`` to execute one statement. Instances are context
    managers: leaving the ``with`` block calls ``close``.
    """

    def __init__(self, name, *, host=None, port=21050, user=None, password=None):
        """Open a connection and create the cursor shared by all queries.

        Parameters
        ----------
        name: str
            Engine name. Unless ``host`` is given, the host is
            ``f"{name}.hostserver.com"``.
        host: str, optional
            Explicit host name, overriding the one derived from ``name``.
        port: int, default 21050
            Server port. NOTE(review): 21050 mirrors impyla's own ``connect``
            default; confirm the port your cluster uses for each engine.
        user, password: str, optional
            Credentials for ``PLAIN`` authentication.
        """
        self.conn = connect(host=host or f"{name}.hostserver.com", auth_mechanism="PLAIN",
                            port=port, user=user, password=password)
        try:
            self.cursor = self.conn.cursor()
        except Exception:
            # Don't leak the already-open connection if no cursor can be made.
            self.conn.close()
            raise

    @abstractmethod
    def _run(self, sql):
        """Run one sql."""
        pass

    @staticmethod
    def _split(sql):
        """Split a script on ``;`` and drop blank (empty/whitespace-only) statements.

        Statements are returned unstripped. NOTE: the split is purely textual,
        so a ``;`` inside a string literal or comment also splits.
        """
        return [statement for statement in sql.split(';') if statement.strip()]

    @timer
    def execute(self, sql):
        """Run one or more statements without fetching any result.

        ``sql`` is either a ``;``-separated string (split with ``_split``) or a
        list of statements, which are run in order.
        """
        assert isinstance(sql, (str, list)), "Invalid type of parameter 'sql'."
        statements = self._split(sql) if isinstance(sql, str) else sql
        for statement in statements:
            self._run(statement)

    @timer
    def fetch(self, sql, tag=None, to_pandas=True):
        """
        Run one sql and get the result.

        Parameters
        ----------
        sql: str
            * tag=None, one query
            * tag!=None, multiple queries separated by ``;``
        tag: str, default None
            Select exactly one query by a comment like '/*COMMENT*/' located at
            its start and return that query's result.
        to_pandas: boolean, default True
            Return a pandas DataFrame; otherwise a list whose first item is the
            tuple of column names, followed by the fetched rows.

        Raises
        ------
        AssertionError
            If ``tag`` matches no query or more than one.

        Examples
        --------
        >>> s = "describe ods.table; /*I want this one*/select * from ods.table2;"
        >>> with Impala() as db:
        ...     df = db.fetch(s, tag="I want this one")
        """
        if tag:
            marker = f"/*{tag}*/"
            stripped = (piece.strip() for piece in self._split(sql))
            sqls = [piece for piece in stripped if piece.startswith(marker)]
            assert len(sqls) == 1, "请检查tag名,重复或不存在"
            sql = sqls.pop()
        self._run(sql)
        if to_pandas:
            from impala.util import as_pandas
            return as_pandas(self.cursor)
        names = tuple(metadata[0] for metadata in self.cursor.description)
        result = list(self.cursor.fetchall())
        result.insert(0, names)
        return result

    def close(self):
        """Close the cursor, then the connection (the latter even if the former fails)."""
        try:
            self.cursor.close()
        finally:
            self.conn.close()

    def __call__(self, sql, to_pandas=True):
        """Run a ``;``-separated script and return the last statement's result.

        All statements but the last go through ``execute``; the last one is
        passed to ``fetch`` together with ``to_pandas``.
        """
        sqls = self._split(sql)
        if sqls[:-1]:
            self.execute(sqls[:-1])
        return self.fetch(sqls[-1], to_pandas=to_pandas)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Always close; returning None lets any exception propagate.
        self.close()
class Impala(BaseConnector):
    """Connector for the Impala engine, with a bulk ``upload`` helper."""

    def __init__(self, **kwargs):
        """Connect to the ``impala`` host.

        Keyword arguments (if any) are forwarded unchanged to
        ``BaseConnector.__init__``.
        """
        BaseConnector.__init__(self, "impala", **kwargs)

    def _run(self, sql):
        """Execute one statement as-is on the shared cursor."""
        self.cursor.execute(sql)

    @staticmethod
    def _sql_literal(value):
        """Render one Python value as an SQL literal for an ``INSERT ... VALUES`` list.

        * ``None`` and pandas' ``NA``/``NaT`` markers -> ``NULL``
        * ``bool`` (Python or NumPy scalar) -> ``true`` / ``false``
        * other numbers -> bare ``str(value)``; NaN -> ``NULL``
        * anything else (str, date, timestamp, ...) -> single-quoted
          ``str(value)`` with backslashes, quotes and line breaks escaped
        """
        import numbers

        kind = type(value).__name__
        # Matched by type name so pandas does not have to be imported here.
        if value is None or kind in ("NAType", "NaTType"):
            return "NULL"
        if kind in ("bool", "bool_"):
            return "true" if value else "false"
        if isinstance(value, numbers.Number):
            # NaN is the only number that is not equal to itself.
            # str() (not repr()) keeps NumPy 2 scalars as plain digits.
            return "NULL" if value != value else str(value)
        text = str(value)
        for raw, escaped in (("\\", "\\\\"), ("'", "\\'"), ("\n", "\\n"), ("\r", "\\r")):
            text = text.replace(raw, escaped)
        return f"'{text}'"

    @timer
    def upload(self, data, name, schema, dtypes, if_exists="fail"):
        """Upload data to an Impala table.

        Parameters
        ----------
        data: pd.DataFrame or list of tuples
            Rows to insert. Values are matched to ``dtypes`` by position, so the
            column order of ``data`` must follow the key order of ``dtypes``.
        name: str
            Name of SQL table.
        schema: str
            Specify the schema.
        dtypes: dict
            Specify each column's data type, e.g. ``{"userid": "bigint"}``.
        if_exists: {"fail", "replace", "append"}, default "fail"
            How to behave if the table already exists.
            * fail: Raise an AssertionError.
            * replace: Drop the table before inserting new values.
            * append: Insert new values into the existing table (it must exist).

        All rows are sent in a single ``INSERT ... VALUES`` statement. When
        ``data`` has no rows, the table is still created/replaced but no
        INSERT is issued.

        Raises
        ------
        AssertionError
            On an invalid ``if_exists`` or a failed existence check.

        Examples
        --------
        >>> with Impala() as db:
        ...     db.upload(data, "test_table", "test", dtypes={"dt": "string", "userid": "bigint", "amount": "float"})
        """
        assert if_exists in ("fail", "replace", "append"), "if_exists的值错误"
        table = f"{schema}.{name}"
        if if_exists == "replace":
            self._run(f"drop table if exists {table}")
        elif if_exists == "append":
            # impyla exposes table_exists on the cursor (HiveServer2Cursor),
            # not on the connection. NOTE(review): confirm for the installed version.
            assert self.cursor.table_exists(name, schema), "该表不存在"
        else:
            assert not self.cursor.table_exists(name, schema), "该表已存在"
        if if_exists != "append":
            definition = ", ".join(f"`{col}` {dtype}" for col, dtype in dtypes.items())
            self._run(f"create table {table} ({definition})")

        # itertuples yields plain tuples that keep each column's own type; a
        # row-wise apply would first squeeze the row into a single-dtype Series.
        if hasattr(data, "itertuples"):
            rows = data.itertuples(index=False, name=None)
        else:
            rows = data
        values = ", ".join(
            "(" + ", ".join(self._sql_literal(v) for v in row) + ")" for row in rows
        )
        if not values:
            return
        # Joined per column, so a single-column table gets no tuple() trailing comma.
        columns = ", ".join(f"`{col}`" for col in dtypes)
        self._run(f"insert into {table} ({columns}) values {values}")
class Hive(BaseConnector):
    """Connector for the Hive engine."""

    def __init__(self, **kwargs):
        """Connect to the ``hive`` host.

        Keyword arguments (if any) are forwarded unchanged to
        ``BaseConnector.__init__``.
        """
        BaseConnector.__init__(self, "hive", **kwargs)

    def _run(self, sql):
        """Execute one statement after replacing ``/* ... */`` comments with a space.

        Some versions of Hive don't support this comment style. The match is
        non-greedy and spans newlines (``re.S``). NOTE: a ``/*`` inside a string
        literal is stripped too.
        """
        sql = re.sub(r'(/\*.*?\*/)', ' ', sql,
                     flags=re.S)
        self.cursor.execute(sql)
使用方法如下所示：
# Example 1: run statements, then fetch a result set.
with Impala() as db:
    db.execute("**some sql**")
    df = db.fetch("**some sql**")
# Example 2: call the connector directly; the last statement's result is returned.
with Impala() as db:
    df = db("**some sql**")
# Example 3: manage the connection by hand; try/finally makes sure close()
# runs even if the query raises.
db = Impala()
try:
    db.fetch("**some sql**")
finally:
    db.close()
 
                   
      
      
 
       
      
