斜率优化

数学——人类精神虐待

Posted by Monad on July 20, 2018

引入

开门见山,斜率优化是用来优化如下的 DP 转移方程的。

$$ f(i) = \max \{ f(j) + g(i) \times x(j) + t(j) + h(i) \} \tag{1} $$

这个方程的特点是,每一个决策 j 的价值不仅与 j 自身有关,还跟当前状态 i 有关。
如果没有 $ g(i) \times x(j) $ 的情况下,我们就可以使用单调队列优化。但是如果有 $ g(i) \times x(j) $ 一项,我们就不能只根据 j 判断哪一个更优。所以我们还要考虑一种更好的方法。

例子

上面的方程有点抽象,先引入一道例题吧:BZOJ 3437
题意就是有 n 个点,每一个点 i 有一个价值 a[i] 和点权 b[i]。我们可以选任意多个点,但是第 n 个必须选。 但是选一个点有一个代价,首先要付出 a[i] 的价格,然后从上一个已选的点 j 开始,付出 $ \sum_{k=j+1}^i b[k] \times (i - k) $ 的代价。问如何选这些点,使得总代价最小。

设最后选的点为 i,其最小代价为 f[i],那么有:

$$ f[i] = \min \{ f[j] + a[i] + \sum_{k=j+1}^i (i-k)b[k] \} \tag{2} $$

由 2 展开,得

$$ f[i] = \min \{ f[j] + a[i] + i\sum_{k=j+1}^i b[k] - \sum_{k=j+1}^i kb[k] \} \tag{3} $$

b[k] 的前缀和为 sum1[k]k*b[k] 的前缀和为 sum2[k]。则

$$ f[i] = \min \{ f[j] + a[i] + i(sum1[i] - sum1[j]) - (sum2[i] - sum2[j]) \tag{4} \} $$

整理,得

$$ f[i] = \min \{ f[j] + a[i] - i \times sum1[j] + sum2[j] + i \times sum1[i] - sum2[i] \tag{5} \} $$

这个转移方程与斜率优化的方程格式相同,其中 $ g(i) = i $, $ x(j) = -sum1[j] $, $ t(j) = sum2[j] $, $ h(i) = -sum2[i] + i * sum1[i] + a[i] $。

斜率优化

分析

如何优化这个式子。

我们可以考虑一个问题,对于同一个 i,任意两个决策 $ j_1 $ 和 $ j_2 $ 那个更优,这里假设 $ j_1 < j_2 $。

如果 $ j_2 $ 比 $ j_1 $ 优的话,那么有

$$ f(j_1) + g(i) \times x(j_1) + t(j_1) + h(i) > f(j_2) + g(i) \times x(j_2) + t(j_2) + h(i) \tag{6} $$

整理,得

$$ g(i) \times \big( x(j_1) - x(j_2) \big) > - \big( f(j_1) + t(j_1) \big) + \big( f(j_2) + t(j_2) \big) \tag{7} $$

设 $ y(j) = - \big( f(j) + t(j) \big) $,则

$$ g(i) \times \big( x(j_1) - x(j_2) \big) > y(j_1) - y(j_2) \tag{8} $$

把 $ x(j_1) - x(j_2) $ 移至右边,考虑 $ x(j_1) - x(j_2) $ 的正负。

在这道题中,因为 $ j_1 < j_2 $,则 $ x(j_1) > x(j_2) $,所以 $ x(j_1) - x(j_2) > 0 $,则

$$ g(i) > \frac{ y(j_1) - y(j_2) }{ x(j_1) - x(j_2) } \tag{9} $$

如果我们把 $ \big( x(j), y(j) \big) $ 看作一个点,那么 $ \frac{ y(j_1) - y(j_2) }{ x(j_1) - x(j_2) } $ 就是斜率的形式,即点 $ \big( x(j_1), y(j_1) \big) $ 到点 $ \big( x(j_2), y(j_2) \big) $ 的斜率。

当 $ g(i) $ 大于 $ \frac{ y(j_1) - y(j_2) }{ x(j_1) - x(j_2) } $ 时,$ j_2 $ 比 $ j_1 $ 优,反之则反。

我们设这个斜率为 $ k(j_1, j_2) $。

提取

我们把每一个决策所对应的点 $ \big( x(j), y(j) \big) $ 画在图上,并且在相邻的两个点之间连线。

连线

我们观察任意三个连续的点,考虑两种情况。


首先是 $ k(j_1, j_2) \geq k(j_2, j_3) $。

情况 1

所以如果有 $ g(i) > \frac{ y(j_1) - y(j_2) }{ x(j_1) - x(j_2) } = k(j_1, j_2) $,那么肯定有 $ g(i) > k(j_2, j_3) $。

也就是说,只要 $ j_2 $ 比 $ j_1 $ 优,那么 $ j_3 $ 肯定比 $ j_2 $ 优,否则 $ j_1 $ 比 $ j_2 $ 优。所以无论如何,$ j_2 $ 都不是最优,所以 $ j_2 $ 就没有保留的价值,可以删掉。


相反,如果是 $ k(j_1, j_2) < k(j_2, j_3) $ 时:

情况 2

那么当 $ g(i) > \frac{ y(j_1) - y(j_2) }{ x(j_1) - x(j_2) } = k(j_1, j_2) $ 时,不一定有 $ g(i) > k(j_2, j_3) $。这时 $ j_2 $ 就有保留的价值。


这样的过程重复地进行,把所有没用的决策都删掉,最后我们就会得到一个下凸壳。

下凸壳

也就是说,我们要维护一个符合 $ k(j_1, j_2) < k(j_2, j_3) $ 的单调队列。当一个 j 加进来时,我们只需要按照单调队列的方法从右边把不符合的点一个个删掉即可。

代码如下所示:

while (qh < qt - 1 && ((getY(q[qt-1]) - getY(q[qt-2])) / (getX(q[qt-1]) - getX(q[qt-2]))) >=
	                  ((getY(i      ) - getY(q[qt-1])) / (getX(i      ) - getX(q[qt-1]))))
	qt --;
q[qt++] = i;

但是除法有精度问题(并且速度较慢),我们可以将其转化为乘法。

while (qh < qt - 1 && ((getY(q[qt-1]) - getY(q[qt-2])) * (getX(i) - getX(q[qt-1]))) >=
	                  ((getX(q[qt-1]) - getX(q[qt-2])) * (getY(i) - getY(q[qt-1]))))
	qt --;
q[qt++] = i;

找决策

那么现在我们有了一个稳点的下凸壳了,但是我们要如何找出对于当前状态 i 最优的决策呢?

我们对于 $ g(i) $ 的增减情况进行讨论。

首先像例题那样,$ g(i) $ 是递增的,所以由式 9 可得,如果对于当前的 g(i),凸壳中(开头的)第二个决策就优于第一个决策,那么对于后面的每一个 g(i),第二个决策都优于第一个决策,所以第一个点就没有存在的必要。

那么每次我们都判断第二个决策是否比第一个优,如果是,就删除第一个点。

代码实现如下(已经把除法转为乘法):

while (qh < qt - 1 && (getY(q[qh+1]) - getY(q[qh])) <= g(i) * (getX(q[qh+1]) - getX(q[qh])))
	qh ++;

然后 q[qh] 就是对于当前的 i 的最优决策。

对于其它情况,如果 $ g(i) $ 递减,那么就从凸壳的右边找。

如果 $ g(i) $ 没有单调性,则只能用二分查找最后一个满足 $ g(i) > \frac{ y(j-1) - y(j) }{ x(j-1) - x(j) } 的 j 即可。

int l=qh, r=qt, mid;
while (l + 1 < r) {
	mid = (l + r) >> 1;
	if ((getY(q[mid]) - getY(q[mid-1])) <= g(i) * (getX(q[mid]) - getX(q[mid-1])))
		l = mid;
	else r = mid;
}
// return l;

时间复杂度

因为每个决策只会进入队列一次,出队也最多一次,如果不用二分查找,那么时间复杂度就是 $ O(n) $,否则就是 $ O(n \log n) $。

其它情况

其实我们的推理过程有几个过程都是用到了“假设”,是基于当前例题而言的。但是在实际的运用中,出题人的毒瘤难以想象,题目有很多种,所以要按照实际情况来。

斜率优化无非就有这么几种不同:$ g(i) $ 的单调性、凸壳的方向、$ x(j) $ 的增减这三个。对它们的不同情况的处理,代入实际情况把上面的过程再推一下是最稳的方式。

最后

最后当然选择把例题的代码放一下啦(逃

// BZOJ3437.cpp
#include <cstdio>

const int MAXN = 1000010;
int n, a[MAXN], b[MAXN], q[MAXN], qh, qt;
long long sum[MAXN][2], f[MAXN];

inline long long getX(const int p) {
	return sum[p][0];
}

inline long long getY(const int p) {
	return f[p] + sum[p][1];
}

void DP() {
	qh = qt = 0;
	q[qt++] = 0;
	for (int i=1, j; i<=n; i++) {
		while (qh < qt - 1 && (getY(q[qh+1]) - getY(q[qh])) <= i * (getX(q[qh+1]) - getX(q[qh])))
			qh ++;

		j = q[qh];
		f[i] = f[j] + (sum[i][0] - sum[j][0]) * i + (sum[j][1] - sum[i][1]) + a[i];

		while (qh < qt - 1 && ((getY(q[qt-1]) - getY(q[qt-2])) * (getX(i) - getX(q[qt-1]))) >=
			                  ((getX(q[qt-1]) - getX(q[qt-2])) * (getY(i) - getY(q[qt-1]))))
			qt --;
		q[qt++] = i;
	}
}

int main() {
	scanf("%d", &n);
	for (int i=1; i<=n; i++)
		scanf("%d", &a[i]);
	for (int i=1; i<=n; i++) {
		scanf("%d", &b[i]);
		sum[i][0] = sum[i-1][0] + b[i];
		sum[i][1] = sum[i-1][1] + (long long)b[i] * i;
	}

	DP();

	printf("%lld\n", f[n]);

	return 0;
}
CC 原创文章采用CC BY-NC-SA 4.0协议进行许可,转载请注明:“转载自:斜率优化