diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 00890a49b9be..1cda0283d2d0 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -13,7 +13,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.tuned_gemm import tgemm
+if torch.version.hip is not None:
+    from vllm.model_executor.layers.tuned_gemm import tgemm
 from vllm.model_executor.utils import set_weight_attrs
 
 logger = init_logger(__name__)
@@ -90,13 +91,17 @@ def apply(self,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         weight = layer.weight
+        if torch.version.hip is not None:
+            mm_result = tgemm.mm(x, weight)
+        else:
+            mm_result = F.linear(x, weight)
         if self.separate_bias_add:
             if bias is not None:
-                return tgemm.mm(x, weight) + bias
-            return tgemm.mm(x, weight)
+                return mm_result + bias
+            return mm_result
         elif bias is not None:
             return F.linear(x, weight, bias)
-        return tgemm.mm(x, weight)
+        return mm_result
 
 
 class LinearBase(torch.nn.Module):
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index d81c00437d75..209019b6d87d 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -6,7 +6,9 @@
 import torch.nn as nn
+import torch.nn.functional as F
 
 from vllm.distributed import tensor_model_parallel_gather
-from vllm.model_executor.layers.tuned_gemm import tgemm
+if torch.version.hip is not None:
+    from vllm.model_executor.layers.tuned_gemm import tgemm
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 
@@ -63,7 +65,10 @@ def forward(
     def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                     embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
         # Get the logits for the next tokens.
-        logits = tgemm.mm(hidden_states, embedding)
+        if torch.version.hip is not None:
+            logits = tgemm.mm(hidden_states, embedding)
+        else:
+            logits = F.linear(hidden_states, embedding)
         if embedding_bias is not None:
             logits += embedding_bias
         logits = tensor_model_parallel_gather(logits)
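
Both hunks apply the same pattern: runtime backend dispatch on `torch.version.hip`, which is a version string on ROCm builds of PyTorch and `None` everywhere else, so the `tuned_gemm` module is only imported and exercised on HIP while other builds fall back to `F.linear`. A minimal self-contained sketch of that pattern, with the branch resolved once at import time instead of inside every forward call, might look like the following; the helper name `rocm_aware_mm` is hypothetical and introduced here purely for illustration:

```python
import torch
import torch.nn.functional as F

# Sketch only: pick the GEMM backend once, at import time, rather than
# re-checking torch.version.hip on every call. torch.version.hip is a
# version string on ROCm builds and None on CUDA/CPU builds.
if torch.version.hip is not None:
    # ROCm build: route unbiased matmuls through the tuned kernel.
    from vllm.model_executor.layers.tuned_gemm import tgemm

    def rocm_aware_mm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        return tgemm.mm(x, weight)
else:
    # Non-ROCm build: tuned_gemm is unavailable, so use the stock kernel.
    def rocm_aware_mm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        return F.linear(x, weight)
```

Hoisting the check to import time keeps the hot path free of a per-call branch and guarantees `tgemm` is never referenced on builds where the conditional import was skipped; the patch itself branches inline inside `apply()` and `_get_logits()`, which is equivalent in behavior as long as the HIP-only path is unreachable on other builds.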