c++ 如何实现Trie删除功能而不产生重叠错误

cunj1qz1  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(97)

我试图在c++中实现一个Trie来解决这个问题https://codeforces.com/problemset/problem/706/D,除了删除函数,我已经把所有的东西都写下来了。出于某种原因,即使我的代码进行检查以确保我们没有删除必要的元素,它仍然这样做。我甚至遵循了数字海洋trie的解释,但这也没有帮助。这导致测试用例#8的WA。
这是我当前的代码,如果你向下滚动一点,你会发现删除函数和它的帮助器方法。

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
 
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const ll INF = 100000000000;
const ll MOD = 1000000007;
const int MAX_N = 1000005;
 
using namespace __gnu_pbds;
template<typename T> using ordered_set = tree<T, null_type, less<T>, 
rb_tree_tag, tree_order_statistics_node_update>;
 
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());

typedef struct trie_node trie_node;

char *decimal_binary(int val) {
    char *bin_rep = (char*)(calloc(32, sizeof(char)));
    for(int i = 31; i >= 0; --i) {
        if(val & (1 << i)) {
            bin_rep[i] = '1';
        } else {
            bin_rep[i] = '0';
        }
    }
    reverse(bin_rep, bin_rep+32);
    return bin_rep;
}

int binary_decimal(char *bs) {
    int sum = 0;
    for(int i = 31, j = 0; i >= 0; --i, ++j) {
        sum += ((int)(bs[i]-'0'))*pow(2, j);
    }
    return sum;
}
 
struct trie_node {
    trie_node *children[2];
    bool is_leaf = false;
};
 
trie_node *make_node() {
    trie_node *node = new trie_node;
    for(int i = 0; i < 2; ++i) {
        node->children[i] = NULL;
    }
    node->is_leaf = false;
    return node;
}
 
void unload_node(trie_node *node) {
    for(int i = 0; i < 2; ++i) {
        if(node->children[i] != NULL) {
            unload_node(node->children[i]);
        } else {
            continue;
        }
    }
    free(node);
}
 
trie_node *insert_node(trie_node *root, char *bs) {
    trie_node *temp = root;
    for(int i = 0; bs[i] != '\0'; ++i) {
        int idx = (int)(bs[i]-'0');
        if(temp->children[idx] == NULL) {
            temp->children[idx] = make_node();
        }
        temp = temp->children[idx];
    }
    temp->is_leaf = true;
    return root;
}

bool check_leaf(trie_node *root, char *bs) {
    trie_node *temp = root;
    for(int i = 0; bs[i]; ++i) {
        int idx = (int)bs[i]-'0';
        if(temp->children[idx] != NULL) {
            temp = temp->children[idx];
        }
    }
    return temp->is_leaf;
}

int earliest_branch(trie_node *root, char *bs) {
    trie_node *temp = root;
    int n = strlen(bs);
    if(n == 0) return 0;
    int last_idx = 0;
    for(int i = 0; i < n; ++i) {
        int idx = bs[i]-'0';
        if(temp->children[idx]) {
            for(int j = 0; j < 2; ++j) {
                if(j != idx && temp->children[j] != NULL) {
                    last_idx = i+1;
                    break;
                }
            }
            temp = temp->children[idx];
        }
    }
    return last_idx;
}

char *longest_prefix(trie_node *root, char *bs) {
    if(!bs || bs[0] == '\0') return NULL;
    
    int n = strlen(bs);
    char *lgt_prefix = (char*)(calloc(n+1, sizeof(char)));
    for(int i = 0; bs[i] != '\0'; ++i) {
        lgt_prefix[i] = bs[i];
    }
    lgt_prefix[n] = '\0';
    
    int branch_idx = earliest_branch(root, lgt_prefix)-1;
    if(branch_idx >= 0) {
        lgt_prefix[branch_idx] = '\0';
        lgt_prefix = (char*)(realloc(lgt_prefix, (branch_idx+1)*sizeof(char)));
    }
    
    return lgt_prefix;
}

trie_node *delete_node(trie_node *root, char *bs) {
    if(!root) return NULL;
    if(!bs || bs[0] == '\0') return root;
    if(!check_leaf(root, bs)) return root;
    
    trie_node *temp = root;
    char *lgt_prefix = longest_prefix(root, bs);
    if(lgt_prefix[0] == '\0') {
        free(lgt_prefix);
        return root;
    }
    int pos;
    for(pos = 0; lgt_prefix[pos] != '\0'; ++pos) {
        int idx = (int)lgt_prefix[pos]-'0';
        if(temp->children[idx] != NULL) {
            temp = temp->children[idx];
        } else {
            free(lgt_prefix);
            return root;
        }
    }
    int n = strlen(bs);
    for(; pos < n; ++pos) {
        int idx = (int)bs[pos]-'0';
        if(temp->children[idx]) {
            trie_node *extra = temp->children[idx];
            temp->children[idx] = NULL;
            unload_node(extra);
        }
    } 
    free(lgt_prefix);
    return root;
}
 
char *search_trie(trie_node *root, char *bs) {
    char *res = (char*)(calloc(32, sizeof(char)));
    for(int i = 0; i < 32; ++i) {
        res[i] = '0';
    }
    trie_node *temp = root;
    for(int i = 0; i < 32; ++i) {
        int idx = (((int)(bs[i]-'0'))+1)%2;
        if(temp->children[idx] != NULL) {
            res[i] = '1';
            temp = temp->children[idx];
        } else if(temp->children[(idx+1)%2] != NULL){
            temp = temp->children[(idx+1)%2];
        } else {
            break;
        }
    }
    return res;
}

int main() {
    cin.tie(0)->sync_with_stdio(0);
    int t;
    cin >> t;
    
    trie_node *root = make_node();
    map<int, int> cnt;
    char *tmpbs = decimal_binary(0);
    root = insert_node(root, tmpbs);
    while(t--) {
        char type; int val;
        cin >> type >> val;
        char *bs = decimal_binary(val);
        if(type == '+') {
            if(cnt[val] == 0) {
                root = insert_node(root, bs);
            }
            ++cnt[val];
        } else if(type == '-') {
            --cnt[val];
            if(cnt[val] == 0) {
                root = delete_node(root, bs);
            }
        } else {
            char *res = search_trie(root, bs);
            int ans = binary_decimal(res);
            cout << ans << "\n";
        }
    }
    unload_node(root);
    
    return 0;
}

谢谢你,谢谢
UPD 1:这里有一个测试用例,我的代码失败了

14
? 1
+ 1
+ 7
? 2
+ 3
? 1
? 6
+ 4
+ 8
- 8
+ 6
+ 6
- 6
? 3

我的输出:

1
5
6
7
5

正确输出:

1
5
6
7
7
ldfqzlk8

ldfqzlk81#

我不知道你到底在做什么,但是在我看来,你在从trie中删除一个元素时丢失了一个键,这就是这个键是trie中另一个长键的前缀键。这里有一个帖子,你可以查看以获得更多细节:https://discuss.boardinfinity.com/t/trie-delete/6253

相关问题