工作中有遇到一种平均分配问题,贷后环节给催收员分案时,既要保证每个人分到的案件量近似,又要保证每个人分到的案件总金额近似,这是一个NP难问题。
下面是一个近似求解的方法。算法思路为:
假设总案件量为N,需要分给m个人
- 先将所有案件随机均分为m组,如果N不能被m整除,则给原案件列表后补0,补足长度至能被m整除为止
- 进入大循环,寻找当前金额总和最大的组和最小的组
- 进入小循环,随机交换两组中的一个元素
- 判断极差是否减小,如果减小,则跳出小循环回到步骤2,如果未减小,则继续小循环回到步骤3
- 当极差小于等于设定值时,跳出大循环,结束
import logging
from itertools import count
class LoggerFactory:
"""Factory to create logger."""
@staticmethod
def stream(name):
"""Stream logger."""
assert hasattr(name, '__name__') or isinstance(name, str), 'Input must be string or has attribute `__name__`.'
if hasattr(name, '__name__'):
name = name.__name__.upper()
else:
name = name.upper()
logger = logging.getLogger(name)
if not logger.handlers:
logger.setLevel(logging.DEBUG)
fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
formatter = logging.Formatter(fmt)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(sh)
return logger
@staticmethod
def file(name):
"""File logger."""
logger = logging.getLogger(name)
if not logger.handlers:
logger.setLevel(logging.DEBUG)
fmt = '%(asctime)s - %(levelname)s - %(message)s'
formatter = logging.Formatter(fmt)
fh = logging.FileHandler(f'{name}.log', encoding='utf-8')
fh.setFormatter(formatter)
logger.addHandler(fh)
return logger
@staticmethod
def both(name):
"""Stream and file logger."""
logger = logging.getLogger(name)
if not logger.handlers:
logger.setLevel(logging.DEBUG)
fmt = '%(asctime)s - %(levelname)s - %(message)s'
formatter = logging.Formatter(fmt)
fh = logging.FileHandler(f'{name}.log', encoding='utf-8')
fh.setFormatter(formatter)
logger.addHandler(fh)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(sh)
return logger
def balance(index, value, group, drift=100, max_iter=500, max_exchange=10000, logger=None, verbose=False):
"""
平均分配问题,保证均分结果中每组个数与值的总和都大致相似
Parameters
----------
index: 1D array
与值相对应的索引
value: 1D array
需要进行均分的值
group: int
均分的组数
drift: float, optional
允许分组结果的极差最大值
max_iter: int, optional
最大迭代次数
max_exchange: int, optional
最大交换次数
logger: logger, optional
日志器
verbose: boolean
是否打印详情
Returns
-------
index : 1D array
均分的值结果所对应的索引
value : 1D array
均分的值结果
"""
import numpy as np
if not logger:
logger = LoggerFactory.stream(balance)
index = index.copy()
value = value.copy()
if len(value) % group != 0:
index = np.append(index, [0] * (group - len(value) % group))
value = np.append(value, [0] * (group - len(value) % group))
ncol = len(value) // group
index = np.reshape(index, (group, ncol))
value = np.reshape(value, (group, ncol))
diff = value.sum(axis=1).ptp()
logger.info('=' * 50)
for epoch in count(1, step=1):
if diff <= drift:
logger.info('Success.')
break
elif epoch > max_iter:
logger.warning('Reach max iter.')
break
max_group_index = value.sum(axis=1).argmax()
min_group_index = value.sum(axis=1).argmin()
for i in range(max_exchange):
x = np.random.randint(0, ncol)
y = np.random.randint(0, ncol)
index[max_group_index, x], index[min_group_index, y] = index[min_group_index, y], index[max_group_index, x]
value[max_group_index, x], value[min_group_index, y] = value[min_group_index, y], value[max_group_index, x]
if value.sum(axis=1).ptp() < diff:
diff = value.sum(axis=1).ptp()
if verbose:
logger.info(f'Iter {epoch}, done after {i + 1} exchange.')
break
else:
logger.warning(f'Iter {epoch}, reach max exchange.')
logger.info(f'Total value in each group:\n{value.sum(axis=1)}')
logger.info(f'Total number in each group:\n{[len(x) for x in [[x for x in row if x != 0] for row in index.tolist()]]}')
return index, value