0%

HDU 4821 字符串hash

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4821

题目大意:

给定一个字符串 S 问 S 中有几个字串可以满足如下条件:

  1. 长度为 m * l;
  2. 可以分为 m 个长度为 l 的字串,并且这些字串不完全相同;

解题思路:

总体的思路可以是一个 l 中的每一个位置 i 作为起始位置,然后判断 i + l, i + 2 * l, i + 3 * l, …..  这些区间的字串是否相同,所以题目的重点在于子串的判重,跟队友组了场 virtual TLE 和 MLE 到死。后来知道了 BKDHash 这个东西 (关于 BKDHash 请参照 http://blog.csdn.net/wanglx_/article/details/40400693 ), 他可以在 O(n) 复杂度的预处理下,O(1) 查询任意字串的 Hash 值,这里具体的使用方法是:

  1. 先设一个种子 Base, 通常为质数;
  2. 打表打出 nbase[i] 表示 Base 的 i 次方;
  3. 打表打出给定字符串的前 i 个字符组成的字符串的哈希值 Hash[i] = Hash[i - 1] * Base + str[i - 1] - ‘a’ + 1;
  4. 上述预处理后对于每段[l, r) 中间的字符串,他的哈希值为 Hash[r] - Hash[l] * nbase[r - l];
  5. 特殊说明:所有的数组的类型均为 unsigned long long 时不用模除;

参考代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/*=============================================================================
# Author: Datasource
# Last modified: 2015-10-03 22:10
# Filename: 4821.cpp
# Description:
=============================================================================*/

#include <map>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <cstdio>
#include <string>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define lson l, mid, ls
#define rson mid, r, rs
#define ls (rt << 1) + 1
#define rs (rt << 1) + 2
using namespace std;
const int MAXN = 100010;

struct Hash{

unsigned long long Base;
unsigned long long nbase[MAXN];
unsigned long long Has[MAXN];

void init(char *s){
Base = 131;
nbase[0] = 1;
Has[0] = 0;
int len = strlen(s);
for(int i = 1; i <= len; i++){
nbase[i] = nbase[i - 1] * Base;
Has[i] = Has[i - 1] * Base + s[i - 1] - 'a' + 1;
}
}

unsigned long long getHash(int l, int r){
return (Has[r] - Has[l] * nbase[r - l]);
}
}_hash;

int m, l;
char ch[MAXN];
map <unsigned long long, int> ma;

int main(){
while(scanf("%d%d", &amp;m, &amp;l) != EOF){
scanf(" %s", ch);
_hash.init(ch);
int ans = 0;
int len = strlen(ch);

for(int i = 0; i < l; i++){
ma.clear();
if(i + m * l >= len) break;
for(int j = i; j < i + m * l; j += l){
unsigned long long hv = _hash.getHash(j, j + l);
ma[hv]++;
}
if(ma.size() == m) ans++;
for(int j = i; j < len; j += l){
unsigned long long behv = _hash.getHash(j, j + l);
ma[behv]--;
if(ma[behv] == 0) ma.erase(behv);
if(j + l * (m + 1) - 1 >= len) break;
unsigned long long hv = _hash.getHash(j + l * m, j + l * (m + 1));
ma[hv]++;
if(ma.size() == m) ans++;
}
}
printf("%d\n", ans);
}
return 0;
}