1 solution

  • 0
    @ 2024-2-20 8:41:42
    #include "deliveries.h"
    
    #include <vector>
    #include <set>
    #include <iostream>
    
    #define MAXN 101000
    
    using namespace std;
    
    // N         : number of tree nodes (set in init).
    // allSumTxW : sum over all nodes v of sumT[v] * W[v], kept current by max_time.
    // allSumW   : sum of all current weights W[] (node 0 carries one extra unit, see init).
    // segId     : number of chains allocated so far by dfs.
    // uFrom     : argument of update(): first (1-based) chain position that receives uValue.
    // uValue    : argument of update(): amount added to every affected position.
    // flippedWT : accumulator written by findC; sum of T[v] * (current subtree weight of v)
    //             over the chain nodes at and above each node findC selects.
    long long N, allSumTxW, allSumW, segId, uFrom, uValue, flippedWT;
    vector<int> edge[MAXN];               // children of each node in the tree rooted at 0
    long long W[MAXN];                    // current weight of each node
    long long T[MAXN];                    // length of the edge to the parent; T[0] stays 0
    int parent[MAXN];                     // parent in the rooted tree; parent[0] == -1
    int num[MAXN];                        // subtree size
    long long sumT[MAXN];                 // distance from the root (prefix sum of T)
    bool heavy[MAXN];                     // true if the node continues its parent's chain
    int seg[MAXN];                        // chain id of each node
    int segPos[MAXN];                     // 1-based position inside its chain, counted from the deepest node
    vector<int> segNodes[MAXN];           // nodes of each chain, deepest first, chain top last
    int nodes[MAXN];                      // scratch buffer read by build()
    int segParent[MAXN];                  // parent of a chain's top node; -1 for the root chain
    long long sumW[MAXN];                 // subtree weight at init time only; live values live in trees[]
    set<pair<long long, int>> maxSumW[MAXN]; // per node: {current subtree weight of a hanging chain's top, chain id}
    
    // Segment-tree node over a contiguous range [a, b] of positions of one chain
    // (positions are 1-based, position 1 = deepest node, position count = chain top).
    struct node{
    	int id, a, b;                       // id: tree node at a leaf, -1 for internal nodes; [a, b]: covered positions
    	long long w, l, sumWT, sumL, sumT;  // w: subtree weight of the right-most (top-most) position, excluding l;
    	                                    // l: pending addition to w for the whole range;
    	                                    // sumWT: sum of T[v] * subtreeWeight(v) over the range, excluding sumL;
    	                                    // sumL: pending addition to every subtree weight, applied as sumL * sumT;
    	                                    // sumT: sum of T[v] over the range (never changes after build)
    	node* leftC;                        // lower positions (deeper nodes); nullptr at leaves
    	node* rightC;                       // higher positions (closer to the chain top); nullptr at leaves
    };
    node* trees[MAXN];                    // trees[s]: root of the segment tree for chain s
    
    // Builds the segment tree for chain positions x..y, where nodes[k - 1] holds the
    // tree node sitting at chain position k. Every segment-tree node stores
    //   w     - subtree weight of the right-most (top-most) position of its range,
    //   sumT  - sum of edge lengths T over its range,
    //   sumWT - sum of T[v] * sumW[v] over its range,
    // with both lazy tags (l, sumL) starting at zero.
    node* build(int x, int y){
    	node* cur = new node;
    	cur->a = x;
    	cur->b = y;
    	cur->id = -1;
    	cur->l = 0;
    	cur->sumL = 0;
    	cur->leftC = nullptr;
    	cur->rightC = nullptr;

    	if(x == y){
    		const int v = nodes[x - 1];
    		cur->id = v;
    		cur->w = sumW[v];
    		cur->sumT = T[v];
    		cur->sumWT = cur->sumT * sumW[v];
    		return cur;
    	}

    	const int mid = (x + y) / 2;
    	node* lo = build(x, mid);
    	node* hi = build(mid + 1, y);
    	cur->leftC = lo;
    	cur->rightC = hi;
    	// Subtree weights grow toward the chain top, so the right child's top value is the range maximum.
    	cur->w = hi->w;
    	cur->sumT = lo->sumT + hi->sumT;
    	cur->sumWT = lo->sumWT + hi->sumWT;
    	return cur;
    }
    
    // Weight value of p's top-most position, including p's own pending tag.
    // Valid only once every ancestor's tag has been pushed down to p.
    long long value(node* p){
    	const long long pending = p->l;
    	return pending + p->w;
    }
    // Sum of T[v] * subtreeWeight(v) over p's range, including p's own pending tag
    // (a tag of d raises every subtree weight by d, i.e. the sum by d * sumT).
    long long sumValue(node* p){
    	const long long pendingPart = p->sumL * p->sumT;
    	return pendingPart + p->sumWT;
    }
    // Moves p's pending tags (l and sumL) onto both children and clears them on p.
    // p must be an internal node.
    void push_down(node* p){
    	node* kids[2] = {p->leftC, p->rightC};
    	for(node* kid : kids){
    		kid->l += p->l;
    		kid->sumL += p->sumL;
    	}
    	p->l = 0;
    	p->sumL = 0;
    }
    // Recomputes p's aggregates from its children (children's tags included).
    void update_node(node* p){
    	node* lc = p->leftC;
    	node* rc = p->rightC;
    	p->sumWT = sumValue(lc) + sumValue(rc);
    	p->w = value(rc);
    }
    
    // Adds uValue to the subtree weight of every chain position >= uFrom inside p's range.
    // The arguments arrive through the globals uFrom / uValue. Callers guarantee uFrom <= p->b.
    void update(node* p){
    	// The whole range sits at or above uFrom: tag it lazily and stop.
    	if(p->a >= uFrom){
    		p->sumL += uValue;
    		p->l += uValue;
    		return;
    	}

    	push_down(p);
    	const int mid = (p->a + p->b) / 2;
    	if(mid >= uFrom)
    		update(p->leftC);
    	// The right half always lies (at least partly) at or above uFrom.
    	update(p->rightC);
    	update_node(p);
    }
    
    // Walks down p's tree to the lowest chain position whose subtree weight is at least
    // ceil(allSumW / 2), and returns that tree node. Along the way it adds to flippedWT the
    // sum of T[v] * subtreeWeight(v) for that position and every position above it in p's range.
    // update_node is needed after each push_down: it keeps this tree node's value consistent,
    // which matters because updateLine later erases maxSumW entries keyed by the root's value.
    int findC(node* p){
    	const long long half = (allSumW + 1) / 2;
    	while(p->a != p->b){
    		push_down(p);
    		update_node(p);
    		if(value(p->leftC) < half){
    			p = p->rightC;
    		} else {
    			// The answer is in the lower half; everything in the upper half lies above it.
    			flippedWT += sumValue(p->rightC);
    			p = p->leftC;
    		}
    	}
    	flippedWT += sumValue(p);
    	return p->id;
    }
    
    // First pass over the subtree of x. It fills parent[], sumT[] (root distance),
    // num[] (subtree size) and sumW[] (subtree weight), and accumulates the global
    // totals allSumW and allSumTxW. sumT[x] must already be set when called.
    void preDfs(int x){
    	num[x] = 1;
    	sumW[x] = W[x];
    	allSumW += W[x];
    	allSumTxW += W[x] * sumT[x];
    	for(const int child : edge[x]){
    		parent[child] = x;
    		sumT[child] = T[child] + sumT[x];
    		preDfs(child);
    		sumW[x] += sumW[child];
    		num[x] += num[child];
    	}
    }
    
    // Second pass (post-order): splits the tree into chains.
    // A child is "heavy" when its subtree size is at least ceil(num[parent] / 2). At most one
    // child qualifies, and it continues its parent's chain. Any other child starts a new chain;
    // that chain is registered in maxSumW / segParent of the parent. Nodes are appended to their
    // chain bottom-up, so segPos counts 1-based from the deepest node.
    void dfs(int x){
    	for(int child : edge[x])
    		dfs(child);

    	const bool isRoot = (x == 0);
    	heavy[x] = !isRoot && num[x] >= (num[parent[x]] + 1) / 2;

    	// A heavy child may already have extended its chain up to x.
    	if(seg[x] == -1){
    		seg[x] = segId;
    		segId++;
    	}
    	const int s = seg[x];

    	if(heavy[x]){
    		seg[parent[x]] = s;
    	} else if(!isRoot){
    		// x is the top of chain s, which hangs off parent[x].
    		maxSumW[parent[x]].insert({sumW[x], s});
    		segParent[s] = parent[x];
    	}

    	segNodes[s].push_back(x);
    	segPos[x] = segNodes[s].size();
    }
    
    // Builds one segment tree per chain from that chain's bottom-to-top node list.
    void build_trees(){
    	for(int s = 0; s < segId; ++s){
    		const auto& chain = segNodes[s];
    		const int len = chain.size();
    		for(int k = 0; k < len; ++k)
    			nodes[k] = chain[k];
    		trees[s] = build(1, len);
    	}
    }
    
    void updateLine(int x, long long diff){
    	uValue = diff;
    	while(x!=-1){
    		int sx = seg[x], px = segParent[sx];
    
    		if(px!=-1)
    			maxSumW[px].erase({value(trees[sx]),sx});
    
    		uFrom = segPos[x];
    		update(trees[sx]);
    		
    		if(px!=-1)
    			maxSumW[px].insert({value(trees[sx]),sx});
    		
    		x = px;
    	}
    }
    
    // Starting from chain sx, locates the weighted median node: the deepest node whose subtree
    // weight is at least ceil(allSumW / 2). Inside a chain, findC finds the candidate. The search
    // then moves into the heaviest chain hanging off that candidate, while that chain still meets
    // the threshold.
    int findCentroid(int sx){
    	const long long half = (allSumW + 1) / 2;
    	for(;;){
    		const int c = findC(trees[sx]);
    		const auto& hanging = maxSumW[c];
    		if(hanging.empty() || hanging.rbegin()->first < half)
    			return c;
    		sx = hanging.rbegin()->second;
    	}
    }
    
    void init(int NN, std::vector<int> UU, std::vector<int> VV, std::vector<int> TT, std::vector<int> WW){
    	N = NN;
    	vector<vector<pair<int,int>>> e(N);
    	for (int i=0; i+1<N; ++i) {
    		e[UU[i]].push_back({VV[i], TT[i]});
    		e[VV[i]].push_back({UU[i], TT[i]});
    	}
    	vector<int> AA(N, -1), q = {0};
    	vector<int> TTT(N, 0);
    	for (int i = 0; i < (int)q.size(); ++i) {
    		int u = q[i];
    		for (auto [v, t] : e[u]) {
    			if (AA[u] == v) continue;
    			AA[v] = u;
    			TTT[v] = t;
    			q.push_back(v);
    		}
    	}
    
    	for(int i=0; i<N-1; i++){
    		edge[AA[i+1]].push_back(i+1);
    		T[i+1] = TTT[i+1];
    	}
    	for(int i=0; i<N; i++){
    		W[i] = WW[i];
    		seg[i] = -1;
    		segParent[i] = -1;
    	}
    
    	W[0]++;
    	parent[0] = -1;
    	preDfs(0); dfs(0);
    	build_trees();
    }
    
    long long max_time(int S, int X) {
    	if(S==0) X++;
    	long long diff = X - W[S];
    	W[S] = X;
    	allSumTxW += diff * sumT[S];
    	allSumW += diff;
    
    	updateLine(S, diff);
    
    	flippedWT = 0;
    	int c = findCentroid(seg[0]);
    
    	return 2 * (allSumTxW + allSumW * sumT[c] - 2 * flippedWT);
    }
    

    Information

    ID
    23
    Time
    2000ms
    Memory
    256MiB
    Difficulty
    10
    Tags
    (None)
    # Submissions
    2
    Accepted
    1
    Uploaded By