G. B.

个人博客

且将新火试新茶,诗酒趁年华


平均分配问题

工作中有遇到一种平均分配问题,贷后环节给催收员分案时,既要保证每个人分到的案件量近似,又要保证每个人分到的案件总金额近似,这是一个NP难问题。

下面是一个近似求解的方法。算法思路为:

假设总案件量为N,需要分给m个人

  1. 先将所有案件随机均分为m组,如果N不能被m整除,则给原案件列表后补0,补足长度至能被m整除为止
  2. 进入大循环,寻找当前金额总和最大的组和最小的组
  3. 进入小循环,随机交换两组中的一个元素
  4. 判断极差是否减小,如果减小,则跳出小循环回到步骤2,如果未减小,则继续小循环回到步骤3
  5. 当极差小于等于设定值时,跳出大循环,结束
import logging
from itertools import count


class LoggerFactory:
    """Factory to create logger."""

    @staticmethod
    def stream(name):
        """Stream logger."""
        assert hasattr(name, '__name__') or isinstance(name, str), 'Input must be string or has attribute `__name__`.'
        if hasattr(name, '__name__'):
            name = name.__name__.upper()
        else:
            name = name.upper()
        logger = logging.getLogger(name)
        if not logger.handlers:
            logger.setLevel(logging.DEBUG)
            fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            formatter = logging.Formatter(fmt)
            sh = logging.StreamHandler()
            sh.setFormatter(formatter)
            logger.addHandler(sh)
        return logger

    @staticmethod
    def file(name):
        """File logger."""
        logger = logging.getLogger(name)
        if not logger.handlers:
            logger.setLevel(logging.DEBUG)
            fmt = '%(asctime)s - %(levelname)s - %(message)s'
            formatter = logging.Formatter(fmt)
            fh = logging.FileHandler(f'{name}.log', encoding='utf-8')
            fh.setFormatter(formatter)
            logger.addHandler(fh)
        return logger

    @staticmethod
    def both(name):
        """Stream and file logger."""
        logger = logging.getLogger(name)
        if not logger.handlers:
            logger.setLevel(logging.DEBUG)
            fmt = '%(asctime)s - %(levelname)s - %(message)s'
            formatter = logging.Formatter(fmt)
            fh = logging.FileHandler(f'{name}.log', encoding='utf-8')
            fh.setFormatter(formatter)
            logger.addHandler(fh)
            sh = logging.StreamHandler()
            sh.setFormatter(formatter)
            logger.addHandler(sh)
        return logger


def balance(index, value, group, drift=100, max_iter=500, max_exchange=10000, logger=None, verbose=False):
    """
    平均分配问题,保证均分结果中每组个数与值的总和都大致相似

    Parameters
    ----------
    index: 1D array
        与值相对应的索引
    value: 1D array
        需要进行均分的值
    group: int
        均分的组数
    drift: float, optional
        允许分组结果的极差最大值
    max_iter: int, optional
        最大迭代次数
    max_exchange: int, optional
        最大交换次数
    logger: logger, optional
        日志器
    verbose: boolean
        是否打印详情

    Returns
    -------
    index : 1D array
        均分的值结果所对应的索引
    value : 1D array
        均分的值结果
    """
    import numpy as np

    if not logger:
        logger = LoggerFactory.stream(balance)

    index = index.copy()
    value = value.copy()
    if len(value) % group != 0:
        index = np.append(index, [0] * (group - len(value) % group))
        value = np.append(value, [0] * (group - len(value) % group))
    ncol = len(value) // group
    index = np.reshape(index, (group, ncol))
    value = np.reshape(value, (group, ncol))

    diff = value.sum(axis=1).ptp()
    logger.info('=' * 50)
    for epoch in count(1, step=1):
        if diff <= drift:
            logger.info('Success.')
            break
        elif epoch > max_iter:
            logger.warning('Reach max iter.')
            break
        max_group_index = value.sum(axis=1).argmax()
        min_group_index = value.sum(axis=1).argmin()
        for i in range(max_exchange):
            x = np.random.randint(0, ncol)
            y = np.random.randint(0, ncol)
            index[max_group_index, x], index[min_group_index, y] = index[min_group_index, y], index[max_group_index, x]
            value[max_group_index, x], value[min_group_index, y] = value[min_group_index, y], value[max_group_index, x]
            if value.sum(axis=1).ptp() < diff:
                diff = value.sum(axis=1).ptp()
                if verbose:
                    logger.info(f'Iter {epoch}, done after {i + 1} exchange.')
                break
        else:
            logger.warning(f'Iter {epoch}, reach max exchange.')
    logger.info(f'Total value in each group:\n{value.sum(axis=1)}')
    logger.info(f'Total number in each group:\n{[len(x) for x in [[x for x in row if x != 0] for row in index.tolist()]]}')
    return index, value

打赏一个呗

取消

感谢您的支持,我会继续努力的!

扫码支持
扫码支持
扫码打赏,你说多少就多少

打开支付宝扫一扫,即可进行扫码打赏哦