基于torch.compile和gptfast代码风格实现ChatGLM模型推理加速_fastgpt代码分析-程序员宅基地

技术标签: NLP自然语言处理  大模型加速  # 大模型  深度学习  人工智能  

       

目录

一、ChatGLM模型代码重构迁移

二、推理的代码重构

三、效果分析对比

参考文章


         torch2.0发布以后模型训练和推理可以实现一行代码加速,试用之后发现效果并不明显。随后gptfast项目也发布,表明它确实是可以实现模型推理的加速,看来之前试用是打开方式不对。最近参考gptfast项目,实现了对ChatGLM模型推理的加速,主要的原理是借助torch.compile对模型推理过程中构建计算图,实现加速。本文的重点工作就是展示模型代码和推理逻辑的迁移实现,以及加速效果的对比,当然这个方案比VLLM和tensort-LLM肯定是差了点,这个不是本文的重点,后面有空了也把vllm和tensort-LLM也写写博客对比一下效率。

一、ChatGLM模型代码重构迁移

      这个工作是真的不是特别好做,需要对模型结构和模型输入输出非常熟悉,同时也要对gptfast项目迁移原则比较熟悉,才能比较快的迁移成功。核心原则是不能有tensor切片操作,同时kvcache这种也要写成固定的长度,计算过程中不断的去填充更新,同时还要放在模型的结构外侧作为一个参数传入,加速才有效果。还有一个点要注意注意力计算的实现,由于torch更新了scaled_dot_product_attention使得最大长度的定长的矩阵计算注意力,和之前动态逐步增加长度的值是一样的,这个是注意力计算中tensor切片改写的前提(验证过确实是一样的)。细节的地方需要注意kvcache的维度形状,解码过程中不同阶段(首次forward和kvcache存在后的)模型输入的full_attention_mask是不一样的。

整体结构

class TransformerGLM(nn.Module):
    def __init__(self, config, device) -> None:
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        rotary_dim = (
            128
        )
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.layers = nn.ModuleList(TransformerBlock(config, i, device) for i in range(config.num_layers))
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)

        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.seq_length = config.seq_length

    def forward(self, input_ids,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.BoolTensor] = None,
                input_pos=None,
                is_input_mask=False,
                kv_caches=None
                ) -> Tensor:

        inputs_embeds = self.embedding(input_ids)
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        rotary_pos_emb = rotary_pos_emb[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        presents = ()
        for i, layer in enumerate(self.layers):
            inputs_embeds, kv_cache = layer(inputs_embeds, rotary_pos_emb=rotary_pos_emb, input_pos=input_pos,
                                            attention_mask=attention_mask, kv_cache=kv_caches[i])
            presents = presents + (kv_cache,)
        hidden_states = self.final_layernorm(inputs_embeds)
        lm_logits = self.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()
        return lm_logits, presents

注意模型的输入新增的有input_pos,模型解码token的位置,kv_caches;模型基本模块上没有变化,精简其中的一下预处理逻辑和分支,主要就是要让torch.compile()能完成计算图的构建。

kvcache模块

class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (2, max_batch_size, max_seq_length, 128)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: S, k_val: [S, B, H, D]
        assert input_pos.shape[0] == k_val.shape[0]
        k_out = self.k_cache

        v_out = self.v_cache
        k_val = k_val.transpose(0, 2).contiguous()

        v_val = v_val.transpose(0, 2).contiguous()
        k_out[:, :, input_pos] = k_val.clone()
        v_out[:, :, input_pos] = v_val.clone()
        k_out = k_out.transpose(0, 2).contiguous()

        v_out = v_out.transpose(0, 2).contiguous()

        return k_out, v_out

模块中各个变量的维度信息都标注好了,作用就是kv缓存载体以及更新逻辑提供一个方法。

其他模块就不一一介绍了,注意selfattention中kvcache的更新

整个模型的代码如下:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from math import gcd
from functools import reduce
import math


def find_multiple(n: int, *args: Tuple[int]) -> int:
    k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
    if n % k == 0:
        return n
    return n + k - (n % k)


class CoreAttention(torch.nn.Module):
    def __init__(self, config, layer_number):
        super(CoreAttention, self).__init__()
        self.config = config
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32

        self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        coeff = self.layer_number
        self.norm_factor *= coeff
        self.coeff = coeff
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
        if attention_mask is None:
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             is_causal=True)
        else:
            attention_mask = ~attention_mask
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             attention_mask)

        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)

        return context_layer


def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (2, max_batch_size, max_seq_length, 128)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: S, k_val: [S, B, H, D]
        assert input_pos.shape[0] == k_val.shape[0]
        k_out = self.k_cache

        v_out = self.v_cache
        k_val = k_val.transpose(0, 2).contiguous()

        v_val = v_val.transpose(0, 2).contiguous()
        k_out[:, :, input_pos] = k_val.clone()
        v_out[:, :, input_pos] = v_val.clone()
        k_out = k_out.transpose(0, 2).contiguous()

        v_out = v_out.transpose(0, 2).contiguous()

        return k_out, v_out


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)


class TransformerGLM(nn.Module):
    def __init__(self, config, device) -> None:
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        rotary_dim = (
            128
        )
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.layers = nn.ModuleList(TransformerBlock(config, i, device) for i in range(config.num_layers))
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)

        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.seq_length = config.seq_length

    def forward(self, input_ids,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.BoolTensor] = None,
                input_pos=None,
                is_input_mask=False,
                kv_caches=None
                ) -> Tensor:

        inputs_embeds = self.embedding(input_ids)
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        rotary_pos_emb = rotary_pos_emb[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        presents = ()
        for i, layer in enumerate(self.layers):
            inputs_embeds, kv_cache = layer(inputs_embeds, rotary_pos_emb=rotary_pos_emb, input_pos=input_pos,
                                            attention_mask=attention_mask, kv_cache=kv_caches[i])
            presents = presents + (kv_cache,)
        hidden_states = self.final_layernorm(inputs_embeds)
        lm_logits = self.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()
        return lm_logits, presents


class MLP(torch.nn.Module):
    """MLP.
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            # **_config_to_kwargs(config)
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            # **_config_to_kwargs(config)
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.config = config
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        # 32
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size

        self.num_multi_query_groups_per_partition = config.multi_query_group_num
        self.qkv_hidden_size = (
                self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
        )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         # device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device,
                               # **_config_to_kwargs(config)
                               )

    def forward(
            self, hidden_states, rotary_pos_emb, input_pos, attention_mask=None, kv_cache=None
    ):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
            [
                self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
            ],
            dim=-1,
        )

        query_layer = query_layer.view(
            query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        key_layer = key_layer.view(
            key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.view(
            value_layer.size()[:-1]
            + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )

        # apply relative positional encoding (rotary embedding)
        query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
        key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # 更新kvcache
        cache_k, cache_v = kv_cache
        cache_k[input_pos] = key_layer
        cache_v[input_pos] = value_layer
        key_layer = cache_k.clone()
        value_layer = cache_v.clone()
        kv_cache = (key_layer, value_layer)

        key_layer = key_layer.unsqueeze(-2)
        key_layer = key_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        key_layer = key_layer.contiguous().view(
            key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.unsqueeze(-2)
        value_layer = value_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        value_layer = value_layer.contiguous().view(
            value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask=attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache


class TransformerBlock(nn.Module):
    def __init__(self, config, layer_number, device) -> None:
        super().__init__()
        self.hidden_dropout = config.hidden_dropout
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                dtype=config.torch_dtype)
        self.mlp = MLP(config, device=device)

    def forward(self, hidden_states, rotary_pos_emb, input_pos, attention_mask=None, kv_cache=None):
        layernorm_output = self.input_layernorm(hidden_states)
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            rotary_pos_emb,
            input_pos,
            attention_mask=attention_mask,
            kv_cache=kv_cache
        )
        residual = hidden_states
        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        mlp_output = self.mlp(layernorm_output)
        residual = layernorm_input
        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output
        return output, kv_cache


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)


if __name__ == '__main__':
    import os

    os.environ['CUDA_VISIBLE_DEVICES'] = "1"
    from transformers import AutoConfig

    model_path = "./chatglm2-6b-merge"
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = TransformerGLM(config, device=None)
    for name, _ in model.named_parameters():
        print(name)

二、推理的代码重构

推理方法,也就是重写transformer模型中的generate这个方法,对于一次生成可以分为第一次解码forward阶段和余下的解码forward阶段。实现分别如下,只实现了greedy search 策略:

@torch.no_grad()
def first_decode_batch(model, input_ids, position_ids, input_pos, attention_mask, kv_caches):
    logits, kv_caches = model(input_ids=input_ids, position_ids=position_ids, input_pos=input_pos, is_input_mask=False,
                              attention_mask=attention_mask, kv_caches=kv_caches)
    logits = logits[:, -1:]
    next_tok = torch.argmax(logits, dim=-1)
    return next_tok, kv_caches


@torch.no_grad()
def decode_one_token_batch(model, input_ids, position_ids, input_pos, attention_mask, kv_caches):
    logits, kv_caches = model(input_ids, position_ids=position_ids, input_pos=input_pos, is_input_mask=True,
                              attention_mask=attention_mask, kv_caches=kv_caches)
    logits = logits[:, -1:]
    next_tok = torch.argmax(logits, dim=-1)
    return next_tok, kv_caches

主要是得到解码过程中模型输出的token和kv_caches。特别要注意的是不能把这两个方法封装到一个类中,然后再进行torch.compile这样模型能正确输出结果,但是推理速度没有提升的,也就是torch.compile并没有生效。

整体的generate逻辑,包含停止符号,模型的初始输入、kvcaches初始化以及attention_mask输入的变化、position_ids的输入变化,batch推理是padding的加入。

def generate_own_batch(model,
                 inputs,
                 sampling_kwargs,
                 eos_token,
                       max_seq_length, max_batch_size):
    device = inputs['input_ids'].device
    cache_shape = (max_seq_length, max_batch_size, 2, 128)
    dtype = torch.bfloat16
    kv_caches = [(torch.zeros(cache_shape, dtype=dtype).to(device), torch.zeros(cache_shape, dtype=dtype).to(device))
                 for _ in range(model.config.num_layers)]

    input_ids = inputs['input_ids']

    ori_input_ids = input_ids.clone()
    position_ids = inputs['position_ids']

    input_pos = []
    for _ in range(max_batch_size):
        pos = list(range(0,input_ids.shape[1]))
        input_pos.append(pos)
    input_pos = torch.tensor(input_pos, device=input_ids.device)

    # input_pos = torch.arange(0, input_ids.shape[1], device=input_ids.device)
    next_token, kv_caches = first_decode_batch(model, input_ids, position_ids, input_pos, None, kv_caches)

    full_attention_mask = torch.ones(max_batch_size, 1, 1, max_seq_length).to(device).bool()
    full_attention_mask[:, :, :, input_pos] = False

    # pading部分为true
    for i in range(full_attention_mask.shape[0]):
        for j in range(input_ids.shape[1]):
            if input_ids[i, j] == 0:
                full_attention_mask[i, :, :, j] = True

    input_ids = torch.cat((input_ids, next_token.clone()), dim=1)
    num_new_tokens = sampling_kwargs["max_length"]
    T = input_ids.size()[1]

    position_ids = position_ids[:,-1:]

    input_pos = []
    for _ in range(max_batch_size):
        pos = [T]
        input_pos.append(pos)
    input_pos = torch.tensor(input_pos, device=next_token.device, dtype=torch.long)

    # position_ids = torch.tensor([[T - 1]], device=next_token.device, dtype=torch.long)
    # input_pos = torch.tensor([T], device=input_ids.device, dtype=torch.long)

    for i in range(num_new_tokens):
        input_pos += 1
        # Actually better for Inductor to codegen attention here
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            full_attention_mask[:, :, :, input_pos] = False
            next_token, kv_caches = decode_one_token_batch(model, next_token, position_ids, input_pos,
                                                     full_attention_mask, kv_caches)

            input_ids = torch.cat((input_ids, next_token.clone()), dim=1)

            if (input_ids == eos_token).sum(dim=1).all():
                break

            position_ids += 1
            # token = next_token.tolist()
            # token = next_token.tolist()[0]
            # generated_tokens.append(token)
    return input_ids, ori_input_ids

推理核心逻辑

    model = TransformerGLM(config=config, device=None)
    checkpoint_dir = Path(model_path)
    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
    converted_state_dict = model.state_dict()

    gen_kwargs = {"max_length": 200, "num_beams": 1,
                  "do_sample": False, "top_p": 0.8,
                  "temperature": 0.95
                  }
    device = "cuda:0"
    model_path = "./chatglm2-6b-merge"
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    eos_token = tokenizer.eos_token_id
......
# 编译加速
    global first_decode_batch, decode_one_token_batch
    decode_one_token_batch = torch.compile(decode_one_token_batch, mode="reduce-overhead", fullgraph=True)
    first_decode_batch = torch.compile(first_decode_batch, dynamic=True, fullgraph=True)

......
generate_own_batch(model, inputs, gen_kwargs, eos_token, max_seq_length, max_batch_size)

核心点在于把解码两阶段的函数使用torch.compile函数包裹一下,真实解码过程中就会进行解码加速。

三、效果分析对比

展示一下glm模型ori、compile、和compile+int8,bs=1,max_seq_length= 1000的情况下的推理速度和效果的对比,6B glm模型,模型输入prompt如下:

[
    "你好",
    "你是谁呀?",
    "你能做什么呀?",
    "你真厉害",
    "真棒呀",
    "再见了",
    "给我推荐一部电影",
    "你知道明天天气怎么样吗?"
]

ori原始transformer的推理效果如下:

使用compile后效果如下:

compile+int8效果如下:

可以看到相同的模型和相同的数据在bs=1下,原始模型推理速度31.7 tokens/s,compile的推理速度68.1 tokens/s,110.9 tokes/s;加速效果确实比较明显。

业务领域上的实验,这里也可以给一个结论,数据就不展示了,业务上生成的token数目每次推理大都在20 tokens以内,结果如下:

这次的分享就到这里为止了,这个迁移后的模型和推理在我们公司的服务端还有个问题,我们服务端采用的多进程异步来实现web服务的,这个gptfast的服务化的集成显示int8不生效,而且bs=1时候的推理加速并没有线下加速效果明显,具体原因一直没有弄明白,可能是其他进程占用服务器资源,导致torch.compile加速失效或者降低。

参考文章

gpt-fast实战

modeling_chatglm

gpt-fast项目源码

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/HUSTHY/article/details/136717509

智能推荐

5个超厉害的资源搜索网站,每一款都可以让你的资源满满!_最全资源搜索引擎-程序员宅基地

文章浏览阅读1.6w次,点赞8次,收藏41次。生活中我们无时不刻不都要在网站搜索资源,但就是缺少一个趁手的资源搜索网站,如果有一个比较好的资源搜索网站可以帮助我们节省一大半时间!今天小编在这里为大家分享5款超厉害的资源搜索网站,每一款都可以让你的资源丰富精彩!网盘传奇一款最有效的网盘资源搜索网站你还在为找网站里面的资源而烦恼找不到什么合适的工具而烦恼吗?这款网站传奇网站汇聚了4853w个资源,并且它每一天都会持续更新资源;..._最全资源搜索引擎

Book类的设计(Java)_6-1 book类的设计java-程序员宅基地

文章浏览阅读4.5k次,点赞5次,收藏18次。阅读测试程序,设计一个Book类。函数接口定义:class Book{}该类有 四个私有属性 分别是 书籍名称、 价格、 作者、 出版年份,以及相应的set 与get方法;该类有一个含有四个参数的构造方法,这四个参数依次是 书籍名称、 价格、 作者、 出版年份 。裁判测试程序样例:import java.util.*;public class Main { public static void main(String[] args) { List <Book>_6-1 book类的设计java

基于微信小程序的校园导航小程序设计与实现_校园导航微信小程序系统的设计与实现-程序员宅基地

文章浏览阅读613次,点赞28次,收藏27次。相比于以前的传统手工管理方式,智能化的管理方式可以大幅降低学校的运营人员成本,实现了校园导航的标准化、制度化、程序化的管理,有效地防止了校园导航的随意管理,提高了信息的处理速度和精确度,能够及时、准确地查询和修正建筑速看等信息。课题主要采用微信小程序、SpringBoot架构技术,前端以小程序页面呈现给学生,结合后台java语言使页面更加完善,后台使用MySQL数据库进行数据存储。微信小程序主要包括学生信息、校园简介、建筑速看、系统信息等功能,从而实现智能化的管理方式,提高工作效率。

有状态和无状态登录

传统上用户登陆状态会以 Session 的形式保存在服务器上,而 Session ID 则保存在前端的 Cookie 中;而使用 JWT 以后,用户的认证信息将会以 Token 的形式保存在前端,服务器不需要保存任何的用户状态,这也就是为什么 JWT 被称为无状态登陆的原因,无状态登陆最大的优势就是完美支持分布式部署,可以使用一个 Token 发送给不同的服务器,而所有的服务器都会返回同样的结果。有状态和无状态最大的区别就是服务端会不会保存客户端的信息。

九大角度全方位对比Android、iOS开发_ios 开发角度-程序员宅基地

文章浏览阅读784次。发表于10小时前| 2674次阅读| 来源TechCrunch| 19 条评论| 作者Jon EvansiOSAndroid应用开发产品编程语言JavaObjective-C摘要:即便Android市场份额已经超过80%,对于开发者来说,使用哪一个平台做开发仍然很难选择。本文从开发环境、配置、UX设计、语言、API、网络、分享、碎片化、发布等九个方面把Android和iOS_ios 开发角度

搜索引擎的发展历史

搜索引擎的发展历史可以追溯到20世纪90年代初,随着互联网的快速发展和信息量的急剧增加,人们开始感受到了获取和管理信息的挑战。这些阶段展示了搜索引擎在技术和商业模式上的不断演进,以满足用户对信息获取的不断增长的需求。

随便推点

控制对象的特性_控制对象特性-程序员宅基地

文章浏览阅读990次。对象特性是指控制对象的输出参数和输入参数之间的相互作用规律。放大系数K描述控制对象特性的静态特性参数。它的意义是:输出量的变化量和输入量的变化量之比。时间常数T当输入量发生变化后,所引起输出量变化的快慢。(动态参数) ..._控制对象特性

FRP搭建内网穿透(亲测有效)_locyanfrp-程序员宅基地

文章浏览阅读5.7w次,点赞50次,收藏276次。FRP搭建内网穿透1.概述:frp可以通过有公网IP的的服务器将内网的主机暴露给互联网,从而实现通过外网能直接访问到内网主机;frp有服务端和客户端,服务端需要装在有公网ip的服务器上,客户端装在内网主机上。2.简单的图解:3.准备工作:1.一个域名(www.test.xyz)2.一台有公网IP的服务器(阿里云、腾讯云等都行)3.一台内网主机4.下载frp,选择适合的版本下载解压如下:我这里服务器端和客户端都放在了/usr/local/frp/目录下4.执行命令# 服务器端给执_locyanfrp

UVA 12534 - Binary Matrix 2 (网络流‘最小费用最大流’ZKW)_uva12534-程序员宅基地

文章浏览阅读687次。题目:http://acm.hust.edu.cn/vjudge/contest/view.action?cid=93745#problem/A题意:给出r*c的01矩阵,可以翻转格子使得0表成1,1变成0,求出最小的步数使得每一行中1的个数相等,每一列中1的个数相等。思路:网络流。容量可以保证每一行和每一列的1的个数相等,费用可以算出最小步数。行向列建边,如果该格子是_uva12534

免费SSL证书_csdn alphassl免费申请-程序员宅基地

文章浏览阅读504次。1、Let's Encrypt 90天,支持泛域名2、Buypass:https://www.buypass.com/ssl/resources/go-ssl-technical-specification6个月,单域名3、AlwaysOnSLL:https://alwaysonssl.com/ 1年,单域名 可参考蜗牛(wn789)4、TrustAsia5、Alpha..._csdn alphassl免费申请

测试算法的性能(以选择排序为例)_算法性能测试-程序员宅基地

文章浏览阅读1.6k次。测试算法的性能 很多时候我们需要对算法的性能进行测试,最简单的方式是看算法在特定的数据集上的执行时间,简单的测试算法性能的函数实现见testSort()。【思想】:用clock_t计算某排序算法所需的时间,(endTime - startTime)/ CLOCKS_PER_SEC来表示执行了多少秒。【关于宏CLOCKS_PER_SEC】:以下摘自百度百科,“CLOCKS_PE_算法性能测试

Lane Detection_lanedetectionlite-程序员宅基地

文章浏览阅读1.2k次。fromhttps://towardsdatascience.com/finding-lane-lines-simple-pipeline-for-lane-detection-d02b62e7572bIdentifying lanes of the road is very common task that human driver performs. This is important ..._lanedetectionlite

推荐文章

热门文章

相关标签