Conversation
Code Review
This PR introduces a new FP8 quantization method, `triton-fp8w8a8g128`, in which weights use per-channel quantization and activations use per-group quantization; a corresponding g64 variant is also added. The implementation is solid overall, covering the new Triton kernels, the quantization method itself, and the matching documentation and command-line argument updates. My suggestions are mainly:

- Improve the performance of the 3D-tensor handling in the `weight_quant` function.
- Make the description of `triton-fp8w8a8g64` in the documentation and command-line help clearer and more precise.

These changes will improve the code's performance and maintainability.
```python
if x.dim() == 3:
    y_quant = torch.empty((x.shape[0], x.shape[1], x.shape[2]), dtype=torch.float8_e4m3fn, device=x.device)
    s_scales = torch.empty((x.shape[0], x.shape[1], 1), dtype=torch.float32, device=x.device)
    for i in range(x.shape[0]):
        y_quant[i], s_scales[i] = mm_weight_quant(x[i])
    return y_quant, s_scales
else:
    y_quant, s_scales = mm_weight_quant(x)
    return y_quant, s_scales
```
The `for` loop over the 3D tensor launches a separate CUDA kernel for each batch item, which is inefficient for large batch sizes.

To improve performance, reshape the 3D tensor to 2D, quantize it in a single batched call, and then reshape the results back to 3D. This greatly reduces kernel-launch overhead.
Suggested change:

```diff
 if x.dim() == 3:
-    y_quant = torch.empty((x.shape[0], x.shape[1], x.shape[2]), dtype=torch.float8_e4m3fn, device=x.device)
-    s_scales = torch.empty((x.shape[0], x.shape[1], 1), dtype=torch.float32, device=x.device)
-    for i in range(x.shape[0]):
-        y_quant[i], s_scales[i] = mm_weight_quant(x[i])
-    return y_quant, s_scales
+    b, m, n = x.shape
+    y_quant, s_scales = mm_weight_quant(x.view(-1, n))
+    return y_quant.view(b, m, n), s_scales.view(b, m, 1)
 else:
     y_quant, s_scales = mm_weight_quant(x)
     return y_quant, s_scales
```
```rst
* ``deepgemm-fp8w8a8-b128``
* ``triton-fp8w8a8-block128``
* ``triton-fp8w8a8g128``: weight per-channel quantization and activation per-group 128 quantization
* ``triton-fp8w8a8g64``: weight per-channel quantization and activation per-group 64 quantization
```
```diff
 help="""Quantization method: vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128
-| deepgemm-fp8w8a8-b128 | triton-fp8w8a8-block128 | awq | awq_marlin""",
+| deepgemm-fp8w8a8-b128 | triton-fp8w8a8-block128 | awq | awq_marlin |
+| triton-fp8w8a8g128 (weight perchannel quant and act per group quant) |
+triton-fp8w8a8g64 (weight perchannel quantization with group size 64)""",
```
The help text for `triton-fp8w8a8g64` could be more complete: state explicitly that activations are also quantized per group, matching the description of `triton-fp8w8a8g128`, to improve clarity.
Suggested change:

```diff
 help="""Quantization method: vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128
 | deepgemm-fp8w8a8-b128 | triton-fp8w8a8-block128 | awq | awq_marlin |
 | triton-fp8w8a8g128 (weight perchannel quant and act per group quant) |
-triton-fp8w8a8g64 (weight perchannel quantization with group size 64)""",
+triton-fp8w8a8g64 (weight perchannel quant and act per group 64 quant)""",
```
Adds a new quantization method in which activations use per-group quantization and weights use per-channel quantization.