# ezpz.models.llama
## Attention

Bases: `Module`

Multi-head attention module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_args` | `ModelArgs` | Model configuration arguments. | *required* |
Attributes:

| Name | Type | Description |
|---|---|---|
| `n_kv_heads` | `int` | Number of key and value heads. |
| `n_heads` | `int` | Number of query heads. |
| `n_local_kv_heads` | `int` | Number of local key and value heads. |
| `n_rep` | `int` | Number of repetitions for local heads. |
| `head_dim` | `int` | Dimension size of each attention head. |
| `wq` | `Linear` | Linear transformation for queries. |
| `wk` | `Linear` | Linear transformation for keys. |
| `wv` | `Linear` | Linear transformation for values. |
| `wo` | `Linear` | Linear transformation for output. |
Source code in `src/ezpz/models/llama.py` (lines 266–387).
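The head-count attributes above are related by simple bookkeeping. A minimal sketch of that arithmetic, assuming Llama-style grouped-query attention and a model-parallel degree of 1; the concrete values are hypothetical:

```python
# Hypothetical values standing in for ModelArgs fields.
dim = 512
n_heads = 8         # query heads
n_kv_heads = 4      # key/value heads (grouped-query attention)
model_parallel = 1  # no tensor parallelism in this sketch

n_local_kv_heads = n_kv_heads // model_parallel  # 4: KV heads on this rank
n_rep = n_heads // n_kv_heads                    # 2: each KV head serves 2 query heads
head_dim = dim // n_heads                        # 64: per-head dimension
```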
### forward(x, freqs_cis)

Forward pass of the attention module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor. | *required* |
| `freqs_cis` | `Tensor` | Precomputed frequency tensor. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Output tensor after attention. |
Source code in src/ezpz/models/llama.py
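A hedged usage sketch. The `ModelArgs` constructor fields and the `(batch, seq_len, dim)` input layout are assumptions based on the Llama reference convention; verify against `src/ezpz/models/llama.py`:

```python
import torch
from ezpz.models.llama import Attention, ModelArgs, precompute_freqs_cis

# Field names below are assumed (Llama-style); verify against ModelArgs.
args = ModelArgs(dim=512, n_heads=8, n_layers=1, vocab_size=32000)

attn = Attention(args)
x = torch.randn(2, 16, 512)                     # (batch, seq_len, dim)
freqs_cis = precompute_freqs_cis(512 // 8, 16)  # head_dim=64 frequencies for 16 positions
out = attn(x, freqs_cis)                        # output tensor after attention
```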
## FeedForward

Bases: `Module`

FeedForward module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dim` | `int` | Input dimension. | *required* |
| `hidden_dim` | `int` | Hidden dimension of the feedforward layer. | *required* |
| `multiple_of` | `int` | Value to ensure the hidden dimension is a multiple of this value. | *required* |
| `ffn_dim_multiplier` | `Optional[float]` | Custom multiplier for the hidden dimension. Defaults to None. | *required* |
Attributes:

| Name | Type | Description |
|---|---|---|
| `w1` | `Linear` | Linear transformation for the first layer. |
| `w2` | `Linear` | Linear transformation for the second layer. |
| `w3` | `Linear` | Linear transformation for the third layer. |
Source code in src/ezpz/models/llama.py
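A hedged usage sketch. The SwiGLU-style wiring noted in the comment follows the Llama convention and is an assumption about how `w1`, `w2`, and `w3` are combined here:

```python
import torch
from ezpz.models.llama import FeedForward

dim = 512
ffn = FeedForward(
    dim=dim,
    hidden_dim=4 * dim,  # typically rounded internally via multiple_of
    multiple_of=256,
    ffn_dim_multiplier=None,
)

x = torch.randn(2, 16, dim)
y = ffn(x)  # Llama convention: w2(silu(w1(x)) * w3(x)); shape preserved
```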
## RMSNorm

Bases: `Module`

Initialize the RMSNorm normalization layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dim` | `int` | The dimension of the input tensor. | *required* |
| `eps` | `float` | A small value added to the denominator for numerical stability. | `1e-06` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `eps` | `float` | A small value added to the denominator for numerical stability. |
| `weight` | `Parameter` | Learnable scaling parameter. |
Source code in src/ezpz/models/llama.py
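For reference, the standard RMSNorm computation: normalize by the root-mean-square over the last dimension, then scale by the learnable `weight`. This sketch shows the math only and is not the ezpz source:

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps), computed over the last dimension,
    # then scaled elementwise by the learnable weight.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

x = torch.randn(2, 16, 512)
weight = torch.ones(512)  # RMSNorm's weight typically initializes to ones
y = rms_norm_reference(x, weight)
```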
## Transformer

Bases: `Module`

Transformer module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_args` | `ModelArgs` | Model configuration arguments. | *required* |
Attributes:

| Name | Type | Description |
|---|---|---|
| `model_args` | `ModelArgs` | Model configuration arguments. |
| `vocab_size` | `int` | Vocabulary size. |
| `n_layers` | `int` | Number of layers in the model. |
| `tok_embeddings` | `ParallelEmbedding` | Token embeddings. |
| `layers` | `ModuleList` | List of Transformer blocks. |
| `norm` | `RMSNorm` | Layer normalization for the model output. |
| `output` | `ColumnParallelLinear` | Linear layer for final output. |
| `freqs_cis` | `Tensor` | Precomputed cosine and sine frequencies. |
Source code in `src/ezpz/models/llama.py` (lines 507–623).
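An end-to-end usage sketch. The `ModelArgs` field names are assumptions based on the Llama reference convention:

```python
import torch
from ezpz.models.llama import ModelArgs, Transformer

# Field names below are assumed (Llama-style); verify against ModelArgs.
args = ModelArgs(dim=256, n_layers=2, n_heads=4, vocab_size=1000)
model = Transformer(args)

tokens = torch.randint(0, 1000, (2, 16))  # (batch, seq_len) token indices
logits = model(tokens)                    # (batch, seq_len, vocab_size) logits
```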
### forward(tokens)

Perform a forward pass through the Transformer model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tokens` | `Tensor` | Input token indices. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Output logits after applying the Transformer model. |
Source code in src/ezpz/models/llama.py
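Continuing the sketch above, the logits plug into a standard next-token cross-entropy loss (a common training pattern, not something this page prescribes):

```python
import torch.nn.functional as F

# Predict token t+1 from positions up to t.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),  # (batch*(seq_len-1), vocab)
    tokens[:, 1:].reshape(-1),                    # shifted targets
)
```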
### from_model_args(model_args)

*classmethod*

Initialize a Transformer model from a ModelArgs object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_args` | `ModelArgs` | Model configuration arguments. | *required* |
Returns:

| Name | Type | Description |
|---|---|---|
| `Transformer` | `Transformer` | Transformer model. |
Source code in src/ezpz/models/llama.py
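A hedged sketch of the classmethod entry point, with the same assumed `ModelArgs` fields as above:

```python
from ezpz.models.llama import ModelArgs, Transformer

args = ModelArgs(dim=256, n_layers=2, n_heads=4, vocab_size=1000)  # assumed fields
model = Transformer.from_model_args(args)  # builds a Transformer from ModelArgs
```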
### init_weights()
[Note: On `init_weights` vs. `reset_parameters`]

Modules may define `reset_parameters` to initialize parameter values.
`reset_parameters` is meant to initialize only directly owned
parameters/buffers, not those of child modules, and it can be used to
give these tensors their initial values.

Separately, users may want custom initialization for their modules,
different from that in `reset_parameters`. For this, we define
`init_weights`. We call it only in the constructor of this
Transformer root module, to avoid reinitializing tensors.
Source code in src/ezpz/models/llama.py
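An illustrative sketch of the convention the note describes, not the ezpz source: `reset_parameters` touches only directly owned tensors, while a custom `init_weights` applies a different scheme and is invoked once from the root module:

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        # nn.Linear.reset_parameters() already ran inside its constructor,
        # initializing only the Linear's own weight and bias.
        self.proj = nn.Linear(dim, dim)

    def init_weights(self) -> None:
        # Custom initialization, different from reset_parameters; the root
        # module calls this once to avoid reinitializing tensors.
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)
```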
## TransformerBlock

Bases: `Module`

TransformerBlock module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `layer_id` | `int` | Identifier for the layer. | *required* |
| `model_args` | `ModelArgs` | Model configuration arguments. | *required* |
Attributes:

| Name | Type | Description |
|---|---|---|
| `n_heads` | `int` | Number of attention heads. |
| `dim` | `int` | Dimension size of the model. |
| `head_dim` | `int` | Dimension size of each attention head. |
| `attention` | `Attention` | Attention module. |
| `feed_forward` | `FeedForward` | FeedForward module. |
| `layer_id` | `int` | Identifier for the layer. |
| `attention_norm` | `RMSNorm` | Layer normalization for attention output. |
| `ffn_norm` | `RMSNorm` | Layer normalization for feedforward output. |
Source code in src/ezpz/models/llama.py
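A hedged usage sketch, again with assumed Llama-style `ModelArgs` fields and input layout:

```python
import torch
from ezpz.models.llama import ModelArgs, TransformerBlock, precompute_freqs_cis

# Field names below are assumed (Llama-style); verify against ModelArgs.
args = ModelArgs(dim=256, n_layers=2, n_heads=4, vocab_size=1000)
block = TransformerBlock(layer_id=0, model_args=args)

x = torch.randn(2, 16, 256)                     # (batch, seq_len, dim)
freqs_cis = precompute_freqs_cis(256 // 4, 16)  # head_dim=64, 16 positions
y = block(x, freqs_cis)                         # shape preserved
```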
### forward(x, freqs_cis)

Perform a forward pass through the TransformerBlock.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor. | *required* |
| `freqs_cis` | `Tensor` | Precomputed cosine and sine frequencies. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Output tensor after applying attention and feedforward layers. |
Source code in src/ezpz/models/llama.py
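The attributes above suggest the usual pre-norm residual wiring of Llama-style blocks. A sketch of that pattern, expressed against the documented attribute names; the exact ezpz implementation may differ:

```python
import torch

def transformer_block_reference(block, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Pre-norm residual pattern: normalize, transform, then add back.
    h = x + block.attention(block.attention_norm(x), freqs_cis)
    return h + block.feed_forward(block.ffn_norm(h))
```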
## apply_rotary_emb(xq, xk, freqs_cis)

Apply rotary embeddings to input tensors using the given frequency tensor.

This function applies rotary embeddings to the query tensor `xq` and key tensor `xk` using the provided frequency tensor `freqs_cis`. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `xq` | `Tensor` | Query tensor to apply rotary embeddings. | *required* |
| `xk` | `Tensor` | Key tensor to apply rotary embeddings. | *required* |
| `freqs_cis` | `Tensor` | Precomputed frequency tensor for complex exponentials. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tuple[Tensor, Tensor]` | Tuple of the modified query and key tensors with rotary embeddings applied. |
Source code in src/ezpz/models/llama.py
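A hedged usage sketch; the `(batch, seq_len, n_heads, head_dim)` layout is assumed here, following the Llama reference convention:

```python
import torch
from ezpz.models.llama import apply_rotary_emb, precompute_freqs_cis

bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)

freqs_cis = precompute_freqs_cis(head_dim, seqlen)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)  # same shapes as inputs
```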
## precompute_freqs_cis(dim, end, theta=10000.0)

Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

This function calculates a frequency tensor with complex exponentials using the given dimension `dim` and the end index `end`. The `theta` parameter scales the frequencies. The returned tensor contains complex values with dtype `complex64`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dim` | `int` | Dimension of the frequency tensor. | *required* |
| `end` | `int` | End index for precomputing frequencies. | *required* |
| `theta` | `float` | Scaling factor for frequency computation. | `10000.0` |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Precomputed frequency tensor with complex exponentials. |
Source code in src/ezpz/models/llama.py
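For intuition, the standard Llama-style computation this function performs, sketched as reference code (the ezpz source may differ in detail):

```python
import torch

def precompute_freqs_cis_reference(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Per-pair rotation frequencies theta^(-2i/dim) for i = 0 .. dim/2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    # Outer product with positions t in [0, end) gives angles t * freq.
    t = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    # Unit complex numbers e^{i * angle}; complex64, shape (end, dim // 2).
    return torch.polar(torch.ones_like(angles), angles)
```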
## repeat_kv(x, n_rep)

Equivalent to `torch.repeat_interleave(x, dim=2, repeats=n_rep)`.
Source code in src/ezpz/models/llama.py
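Since the docstring states the equivalence directly, a quick check against `torch.repeat_interleave`; the `(batch, seq_len, n_kv_heads, head_dim)` layout is assumed:

```python
import torch
from ezpz.models.llama import repeat_kv

x = torch.randn(2, 16, 4, 64)  # (batch, seq_len, n_kv_heads, head_dim)
y = repeat_kv(x, 2)            # -> (2, 16, 8, 64): each KV head repeated twice
assert torch.equal(y, torch.repeat_interleave(x, repeats=2, dim=2))
```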
## reshape_for_broadcast(freqs_cis, x)

Reshape a frequency tensor for broadcasting it with another tensor.

This function reshapes the frequency tensor to a shape that is broadcast-compatible with the target tensor `x`, so the frequency tensor can be applied in element-wise operations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `freqs_cis` | `Tensor` | Frequency tensor to be reshaped. | *required* |
| `x` | `Tensor` | Target tensor for broadcasting compatibility. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Reshaped frequency tensor. |
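A hedged shape sketch. Following the Llama reference convention, `x` here is the complex view of a query tensor with shape `(batch, seq_len, n_heads, head_dim // 2)`, and the result broadcasts over the batch and head dimensions; the exact accepted shapes are an assumption:

```python
import torch
from ezpz.models.llama import precompute_freqs_cis, reshape_for_broadcast

xq = torch.randn(2, 16, 8, 64)  # (batch, seq_len, n_heads, head_dim)
# Complex view pairs the last dimension: (2, 16, 8, 32), complex64.
xq_c = torch.view_as_complex(xq.reshape(2, 16, 8, 32, 2))

freqs_cis = precompute_freqs_cis(64, 16)        # (16, 32), complex64
freqs = reshape_for_broadcast(freqs_cis, xq_c)  # broadcastable with xq_c
rotated = torch.view_as_real(xq_c * freqs).flatten(3)  # back to real, (2, 16, 8, 64)
```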