【数据结构】线段树

线段树(Segment Tree)是一种基于分治思想的二叉树结构,用于统计区间信息。相比于按二进制位进行区间划分的树状数组(Binary Indexed Trees),线段树更加通用。

线段树满足线段树性质(自创命名233):
①线段树每个结点都是一个区间。
②线段树具有唯一的根结点,代表整个区间的统计范围,如$[1,N]$。
③线段树的每个叶结点都代表一个长度为1的元区间$[x,x]$。
④对于每个内部结点$[l,r]$,它的左儿子是$[l,mid]$,右儿子是$[mid+1,r]$,其中$mid=\lfloor (l+r)\div 2 \rfloor$。

显然对于一棵线段树,除去树的最后一层,就是一棵满二叉树,深度为$O(\log N)$。可以用类似于二叉堆类似的方法给结点编号:
①根结点编号为1。
②编号为$x$的结点的左孩子编号为$x << 1(x \times 2)$,右孩子为$(x << 1)+1(x \times 2 + 1)$,父亲编号为$x >> 1(\lfloor x \div 2 \rfloor)$。

这样我们就可以用一个数组来保存整棵线段树,就像二叉堆一样。
显然空间复杂度为$O(4 \times N)$。

线段树建树

class SegmentTree{
    public:int l,r;
    public:int data;
};
SegmentTree t[SIZE * 4];

inline int LEFT(int x){
    return x << 1;
}

inline int RIGHT(int x){
    return LEFT(x) + 1;
}

void build(int p,int l,int r){
    t[p].l = l;t[p].r = r;
    if(l == r){
        t[p].data = a[l];
        return;
    }
    int mid = (l + r) / 2;
    build(LEFT(p),l,mid);
    build(RIGHT(p),mid + 1,r);
    t[p].data = max(t[LEFT(p)].data,t[RIGHT(p)].data);
}

build(1,1,n);//call the build tree function

每个叶结点$[i,i]$保存$A[i]$的值。线段树的二叉树结构使其很方便地从上往下传递信息。
以区间最大值为例,记$data (l,r)=\max_{l\leq i\leq r}{ A[i]}$,显然$data (l,r)=\max (data (l,mid),data (mid + 1,r))$。

线段树单点修改

单点修改就形如$C\ x\ v$的指令,表示把$A[x]$的值修改为$v$。

void change(int p,int x,int v){
    if(t[p].l == t[p].r){
        t[p].data = v;
        return;
    }
    int mid = (t[p].l + t[p].r) / 2;
    if(x <= mid){
        change(LEFT(p),x,v);
    }else{
        change(RIGHT(p),x,v);
    }
    t[p].data = max(t[LEFT(p)].data,t[RIGHT(p)].data);
}
change(1,x,v);

在线段树中,根结点是执行各种操作的入口。我们从根结点出发,递归找到代表区间$[x,x]$的叶结点,然后从下往上更新$[x,x]$以及它所有祖先结点上保存的信息。时间复杂度为$O(\log n)$

线段树的区间查询

区间查询是形如$Q\ l\ r$的指令,例如查询序列$A$在区间$[l,r]$上的最大值。我们只需从根结点开始,递归执行以下过程:

①若$[l,r]$完全覆盖了当前结点代表区间,则立即回溯,并且该结点$data$值为候选答案。

②若左儿子与$[l,r]$有重叠部分,则递归访问左儿子。

③若右儿子与$[l,r]$有重叠部分,则递归访问右儿子。

int ask(int p,int l,int r){
    if(l <= t[p].l && r >= t[p].r){
        return t[p].data;
    }
    int mid = (t[p].l + t[p].r) / 2;
    int MINUS_INF = -0x7fffffff;
    if(l <= mid){
        MINUS_INF = max(MINUS_INF,ask(LEFT(p),l,r));
    }
    if(r > mid){
        MINUS_INF = max(MINUS_INF,ask(RIGHT(p),l,r));
    }
    return MINUS_INF;
}
printf("%d\n",ask(1,l,r));