我试图在c++中实现一个Trie来解决这个问题https://codeforces.com/problemset/problem/706/D,除了删除函数,我已经把所有的东西都写下来了。出于某种原因,即使我的代码进行检查以确保我们没有删除必要的元素,它仍然这样做。我甚至遵循了数字海洋trie的解释,但这也没有帮助。这导致测试用例#8的WA。
这是我当前的代码,如果你向下滚动一点,你会发现删除函数和它的帮助器方法。
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const ll INF = 100000000000;
const ll MOD = 1000000007;
const int MAX_N = 1000005;
using namespace __gnu_pbds;
template<typename T> using ordered_set = tree<T, null_type, less<T>,
rb_tree_tag, tree_order_statistics_node_update>;
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
typedef struct trie_node trie_node;
char *decimal_binary(int val) {
char *bin_rep = (char*)(calloc(32, sizeof(char)));
for(int i = 31; i >= 0; --i) {
if(val & (1 << i)) {
bin_rep[i] = '1';
} else {
bin_rep[i] = '0';
}
}
reverse(bin_rep, bin_rep+32);
return bin_rep;
}
int binary_decimal(char *bs) {
int sum = 0;
for(int i = 31, j = 0; i >= 0; --i, ++j) {
sum += ((int)(bs[i]-'0'))*pow(2, j);
}
return sum;
}
struct trie_node {
trie_node *children[2];
bool is_leaf = false;
};
trie_node *make_node() {
trie_node *node = new trie_node;
for(int i = 0; i < 2; ++i) {
node->children[i] = NULL;
}
node->is_leaf = false;
return node;
}
void unload_node(trie_node *node) {
for(int i = 0; i < 2; ++i) {
if(node->children[i] != NULL) {
unload_node(node->children[i]);
} else {
continue;
}
}
free(node);
}
trie_node *insert_node(trie_node *root, char *bs) {
trie_node *temp = root;
for(int i = 0; bs[i] != '\0'; ++i) {
int idx = (int)(bs[i]-'0');
if(temp->children[idx] == NULL) {
temp->children[idx] = make_node();
}
temp = temp->children[idx];
}
temp->is_leaf = true;
return root;
}
bool check_leaf(trie_node *root, char *bs) {
trie_node *temp = root;
for(int i = 0; bs[i]; ++i) {
int idx = (int)bs[i]-'0';
if(temp->children[idx] != NULL) {
temp = temp->children[idx];
}
}
return temp->is_leaf;
}
int earliest_branch(trie_node *root, char *bs) {
trie_node *temp = root;
int n = strlen(bs);
if(n == 0) return 0;
int last_idx = 0;
for(int i = 0; i < n; ++i) {
int idx = bs[i]-'0';
if(temp->children[idx]) {
for(int j = 0; j < 2; ++j) {
if(j != idx && temp->children[j] != NULL) {
last_idx = i+1;
break;
}
}
temp = temp->children[idx];
}
}
return last_idx;
}
char *longest_prefix(trie_node *root, char *bs) {
if(!bs || bs[0] == '\0') return NULL;
int n = strlen(bs);
char *lgt_prefix = (char*)(calloc(n+1, sizeof(char)));
for(int i = 0; bs[i] != '\0'; ++i) {
lgt_prefix[i] = bs[i];
}
lgt_prefix[n] = '\0';
int branch_idx = earliest_branch(root, lgt_prefix)-1;
if(branch_idx >= 0) {
lgt_prefix[branch_idx] = '\0';
lgt_prefix = (char*)(realloc(lgt_prefix, (branch_idx+1)*sizeof(char)));
}
return lgt_prefix;
}
trie_node *delete_node(trie_node *root, char *bs) {
if(!root) return NULL;
if(!bs || bs[0] == '\0') return root;
if(!check_leaf(root, bs)) return root;
trie_node *temp = root;
char *lgt_prefix = longest_prefix(root, bs);
if(lgt_prefix[0] == '\0') {
free(lgt_prefix);
return root;
}
int pos;
for(pos = 0; lgt_prefix[pos] != '\0'; ++pos) {
int idx = (int)lgt_prefix[pos]-'0';
if(temp->children[idx] != NULL) {
temp = temp->children[idx];
} else {
free(lgt_prefix);
return root;
}
}
int n = strlen(bs);
for(; pos < n; ++pos) {
int idx = (int)bs[pos]-'0';
if(temp->children[idx]) {
trie_node *extra = temp->children[idx];
temp->children[idx] = NULL;
unload_node(extra);
}
}
free(lgt_prefix);
return root;
}
char *search_trie(trie_node *root, char *bs) {
char *res = (char*)(calloc(32, sizeof(char)));
for(int i = 0; i < 32; ++i) {
res[i] = '0';
}
trie_node *temp = root;
for(int i = 0; i < 32; ++i) {
int idx = (((int)(bs[i]-'0'))+1)%2;
if(temp->children[idx] != NULL) {
res[i] = '1';
temp = temp->children[idx];
} else if(temp->children[(idx+1)%2] != NULL){
temp = temp->children[(idx+1)%2];
} else {
break;
}
}
return res;
}
int main() {
cin.tie(0)->sync_with_stdio(0);
int t;
cin >> t;
trie_node *root = make_node();
map<int, int> cnt;
char *tmpbs = decimal_binary(0);
root = insert_node(root, tmpbs);
while(t--) {
char type; int val;
cin >> type >> val;
char *bs = decimal_binary(val);
if(type == '+') {
if(cnt[val] == 0) {
root = insert_node(root, bs);
}
++cnt[val];
} else if(type == '-') {
--cnt[val];
if(cnt[val] == 0) {
root = delete_node(root, bs);
}
} else {
char *res = search_trie(root, bs);
int ans = binary_decimal(res);
cout << ans << "\n";
}
}
unload_node(root);
return 0;
}
谢谢你,谢谢
UPD 1:这里有一个测试用例,我的代码失败了
14
? 1
+ 1
+ 7
? 2
+ 3
? 1
? 6
+ 4
+ 8
- 8
+ 6
+ 6
- 6
? 3
我的输出:
1
5
6
7
5
正确输出:
1
5
6
7
7
1条答案
按热度按时间ldfqzlk81#
我不知道你到底在做什么,但是在我看来,你在从trie中删除一个元素时丢失了一个键,这就是这个键是trie中另一个长键的前缀键。这里有一个帖子,你可以查看以获得更多细节:https://discuss.boardinfinity.com/t/trie-delete/6253