Problem
敌方 $n$ 台伤害 $a_i$,耐久 $d_i$ 的武器,我方伤害 $ATK$,按如下方式进行操作
- 选择任意一台敌方武器造成 $ATK$ 伤害,若造成伤害后其生命值 小于0 则毁坏
- 敌方每台未毁坏的武器对我方造成等于其攻击力的伤害
若开始时能秒杀两台武器,问我方最少受到多少伤害?
Thought
首先计算一下如何安排攻击敌方的顺序
每个武器能承受 $\left \lceil \frac{d}{ATK}\right\rceil$ 次攻击,设为 $t$,能够攻击 $t-1$ 次
假设当前有两台武器 $(a_1,t_1),(a_2,t_2)$
先攻击 $1$ 号武器的总损失是
$$ (t_1-1)*a_1+(t_1+t_2-1)*a_2 $$
先攻击 $2$ 号武器的总损失是
$$ (t_2-1)*a_2+(t_1+t_2-1)*a_1 $$
若应该先攻击 $1$ 号武器,则
$$ (t_1-1)*a_1+(t_1+t_2-1)*a_2\leq(t_2-1)*a_2+(t_1+t_2-1)*a_1 $$
变形可得
$$ \frac{t_1}{a_1}\leq\frac{t_2}{a_2} $$
同理可得应该先攻击 $2$ 号武器时有
$$ \frac{t_1}{a_1}\ge\frac{t_2}{a_2} $$
也就是说需要以 $t/a$ 为关键字从小到大排序
那么去掉哪两个点最优呢?
设去掉第 $x$ 个点贡献为 $c_x=(t_x\sum_{i=x}^na_i)-a_x$
那么若同时去掉两个点 $i<j$ 会导致重复计算 $t_i*a_j$ 的贡献
考虑先确定一个点 $i$,那么另外一个点 $j$ 就应该是使得 $c_j-t_i*a_j$ 最大的点
将其看做是平面上一个斜率为 $-a$,截距为 $c$ 的点,那么只需要求在 $t_i$ 处的最大值
从后往前加点,用李超线段树维护
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int read() {
int ans = 0 , flag = 1;
char ch = getchar();
while(ch > '9' || ch < '0') {if(ch == '-') flag = - flag; ch = getchar();}
while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
return ans * flag;
}
const ll N = 300010;
struct fighter {
ll t , a , c;
bool operator < (const fighter x) const {return (double)t / (double)a < (double)x.t / (double)x.a;}
}f[N];
int n , ATK;
int dif , T[N];
int suma , sumt;
ll sum;
#define L (x << 1)
#define R (x << 1 | 1)
#define mid ((l + r) >> 1)
ll cross(int x , int pos) {return f[x].c - T[pos] * f[x].a;}
struct Line {
int id[N << 2];
void insert(int x , int l , int r , int p) {
if(id[x] == 0) id[x] = p;
else {
if(cross(id[x] , mid) < cross(p , mid))
swap(id[x] , p);
if(l == r) return;
if(cross(id[x] , l) < cross(p , l))
insert(L , l , mid , p);
if(cross(id[x] , r) < cross(p , r))
insert(R , mid + 1 , r , p);
}
}
ll query(int x , int l , int r , int p) {
if(!id[x]) return 0;
ll ans = cross(id[x] , p);
if(p <= mid) return max(ans , query(L , l , mid , p));
else return max(ans , query(R , mid + 1 , r , p));
}
}t;
int main() {
freopen("in" , "r" , stdin);
n = read(); ATK = read();
for(ll i = 1 ; i <= n ; ++ i) {
int a = read() , d = read();
f[i].a = a; f[i].t = d / ATK + (d % ATK ? 1 : 0);
T[i] = f[i].t;
}
sort(f + 1 , f + n + 1);
sort(T + 1 , T + n + 1);
dif = unique(T + 1 , T + n + 1) - T - 1;
for(int i = 1 ; i <= n ; ++ i) {
sumt += f[i].t;
sum += f[i].a * (sumt - 1);
f[i].c += f[i].a * (sumt - 1);
}
for(int i = n ; i >= 1 ; -- i) {
f[i].c += suma * f[i].t;
f[i].t = lower_bound(T + 1 , T + dif + 1 , f[i].t) - T;
suma += f[i].a;
}
ll ans = 0;
for(int i = n ; i >= 1 ; -- i) {
ans = max(ans , f[i].c + t.query(1 , 1 , dif , f[i].t));
t.insert(1 , 1 , dif , i);
}
printf("%lld\n" , sum - ans);
return 0;
}