forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
block.py
88 lines (66 loc) · 2.46 KB
/
block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Token blocks."""
from typing import TYPE_CHECKING, Iterator, List, Optional
from vllm.utils import Device
# Value that PhysicalTokenBlock.__init__ stores in ``last_accessed`` for a
# newly constructed block.
DEFAULT_LAST_ACCESSED_TIME: float = -1
class PhysicalTokenBlock:
    """State record for one block of the KV cache.

    The constructor stores its five arguments (``device``, ``block_number``,
    ``block_size``, ``block_hash``, ``num_hashed_tokens``) as attributes of
    the same name. Every new block additionally starts with ``ref_count``
    equal to 0, ``last_accessed`` equal to ``DEFAULT_LAST_ACCESSED_TIME`` and
    ``computed`` equal to ``False``.
    """

    def __init__(
        self,
        device: Device,
        block_number: int,
        block_size: int,
        block_hash: int,
        num_hashed_tokens: int,
    ) -> None:
        # Values supplied by the caller.
        self.device = device
        self.block_number = block_number
        self.block_size = block_size
        self.block_hash = block_hash
        self.num_hashed_tokens = num_hashed_tokens

        # Bookkeeping that always starts from the same fixed values.
        self.ref_count = 0
        self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
        self.computed = False

    def __repr__(self) -> str:
        """Return a one-line summary of the block.

        ``block_size`` and ``block_hash`` are not part of the summary.
        """
        shown = (
            "device",
            "block_number",
            "num_hashed_tokens",
            "ref_count",
            "last_accessed",
            "computed",
        )
        rendered = ", ".join(f"{name}={getattr(self, name)}" for name in shown)
        return f"PhysicalTokenBlock({rendered})"
class BlockTable:
    """Ordered collection of blocks with a cached list of their block ids.

    ``_blocks`` holds the block objects and ``_block_ids`` holds the
    ``block_number`` of the block at the same index. Every mutating method
    updates both lists together, so ``ids()`` never has to be recomputed.
    """

    def __init__(self, blocks: Optional[List["PhysicalTokenBlock"]] = None):
        """Create a table, optionally pre-filled with ``blocks`` in order."""
        self._blocks: List["PhysicalTokenBlock"] = []
        self._block_ids: List[int] = []

        if blocks is not None:
            for block in blocks:
                self.append(block)

    def append(self, block: "PhysicalTokenBlock"):
        """Add ``block`` at the end and record its ``block_number``."""
        self._blocks.append(block)
        self._block_ids.append(block.block_number)

    def __len__(self) -> int:
        """Return the number of blocks in the table."""
        return len(self._blocks)

    def __getitem__(self, key):
        """Return the block at index ``key``, or a list of blocks for a slice."""
        return self._blocks[key]

    if TYPE_CHECKING:
        # Declared for static type checkers only. At runtime no ``__iter__``
        # exists on the class, so ``for block in table`` falls back to
        # calling ``__getitem__`` with 0, 1, 2, ... until IndexError.
        def __iter__(self) -> Iterator["PhysicalTokenBlock"]:
            raise RuntimeError("Method should be automatically generated")

    def __setitem__(self, key, value):
        """Replace the block(s) at ``key`` and keep the id cache in step.

        Args:
            key: An integer index or a slice.
            value: A single block for an integer ``key``; an iterable of
                blocks for a slice ``key``.
        """
        if isinstance(key, slice):
            # ``value`` is needed twice (blocks and their ids). Materialize it
            # first so a one-shot iterable such as a generator is not already
            # exhausted when the ids are derived, which would leave the two
            # lists with different contents.
            blocks = list(value)
            block_ids = [b.block_number for b in blocks]
            self._blocks[key] = blocks
            self._block_ids[key] = block_ids
        else:
            # Read the id before mutating so a value without ``block_number``
            # fails without touching either list.
            block_id = value.block_number
            self._blocks[key] = value
            self._block_ids[key] = block_id

    def reset(self):
        """Drop all blocks and ids by rebinding both lists to new empty lists."""
        self._blocks = []
        self._block_ids = []

    def copy(self) -> "BlockTable":
        """Return a new table with its own lists referencing the same blocks."""
        return BlockTable(self._blocks)

    def list(self) -> List["PhysicalTokenBlock"]:
        """Return the internal list of blocks (the list itself, not a copy)."""
        return self._blocks

    def ids(self) -> List[int]:
        """Return the internal list of block ids (the list itself, not a copy)."""
        return self._block_ids