Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 56 additions & 29 deletions examples/llm_eval/lm_eval_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@

import datasets
from lm_eval import utils
from lm_eval.__main__ import cli_evaluate, parse_eval_args, setup_parser
from packaging.version import Version

if not version("lm_eval").startswith("0.4.8"):
warnings.warn(
f"lm_eval_hf.py is tested with lm-eval 0.4.8; found {version('lm_eval')}. "
"Later versions may have incompatible API changes."
)
if Version(version("lm_eval")) < Version("0.4.10"):
raise ImportError(f"lm_eval_hf.py requires lm-eval >= 0.4.10; found {version('lm_eval')}.")

from lm_eval._cli import HarnessCLI
from lm_eval.api.model import T
from lm_eval.models.huggingface import HFLM
from lm_eval.utils import setup_logging
from quantization_utils import quantize_model
from sparse_attention_utils import sparsify_model

Expand Down Expand Up @@ -160,9 +160,24 @@ def create_from_arg_string(
HFLM.create_from_arg_string = classmethod(create_from_arg_string)


def setup_parser_with_modelopt_args():
"""Extend the lm-eval argument parser with ModelOpt quantization and sparsity options."""
parser = setup_parser()
# ModelOpt-specific args that we add to lm-eval's parser. After parsing, these are
# moved out of the argparse namespace and into args.model_args so they reach
# HFLM.create_from_arg_obj (and so lm-eval's own arg validation doesn't reject them).
_MODELOPT_ARG_KEYS = (
"quant_cfg",
"calib_batch_size",
"calib_size",
"auto_quantize_bits",
"auto_quantize_method",
"auto_quantize_score_size",
"auto_quantize_checkpoint",
"compress",
"sparse_cfg",
)


def _add_modelopt_args(parser):
"""Extend an lm-eval argument parser with ModelOpt quantization and sparsity options."""
parser.add_argument(
"--quant_cfg",
type=str,
Expand Down Expand Up @@ -221,33 +236,45 @@ def setup_parser_with_modelopt_args():
type=str,
help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)",
)
return parser


if __name__ == "__main__":
parser = setup_parser_with_modelopt_args()
args = parse_eval_args(parser)
model_args = utils.simple_parse_args_string(args.model_args)
def _inject_modelopt_args_into_model_args(args):
"""Move ModelOpt args from the argparse namespace into args.model_args.

args.model_args is a dict (parsed by lm-eval's MergeDictAction). The ModelOpt
keys must be removed from the namespace so EvaluatorConfig.from_cli doesn't
reject them as unknown kwargs.
"""
model_args = dict(args.model_args) if args.model_args else {}
Comment thread
kevalmorabia97 marked this conversation as resolved.

if args.trust_remote_code:
if getattr(args, "trust_remote_code", False):
# Propagate the user-provided --trust_remote_code flag (not hardcoded).
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
model_args["trust_remote_code"] = True
args.trust_remote_code = None
Comment thread
kevalmorabia97 marked this conversation as resolved.

model_args.update(
{
"quant_cfg": args.quant_cfg,
"auto_quantize_bits": args.auto_quantize_bits,
"auto_quantize_method": args.auto_quantize_method,
"auto_quantize_score_size": args.auto_quantize_score_size,
"auto_quantize_checkpoint": args.auto_quantize_checkpoint,
"calib_batch_size": args.calib_batch_size,
"calib_size": args.calib_size,
"compress": args.compress,
"sparse_cfg": args.sparse_cfg,
}
)
for key in _MODELOPT_ARG_KEYS:
if hasattr(args, key):
model_args[key] = getattr(args, key)
Comment thread
kevalmorabia97 marked this conversation as resolved.
delattr(args, key)
Comment thread
kevalmorabia97 marked this conversation as resolved.

args.model_args = model_args

cli_evaluate(args)

if __name__ == "__main__":
setup_logging()
cli = HarnessCLI()
# The `run` subcommand owns the model/task arguments; extend that parser.
# `_subparsers` is private API; guard so a future lm-eval refactor surfaces a
# clear error instead of an opaque AttributeError.
try:
run_parser = cli._subparsers.choices["run"]
Comment thread
kevalmorabia97 marked this conversation as resolved.
except (AttributeError, KeyError) as e:
raise RuntimeError(
"Cannot locate lm-eval's `run` subparser; the HarnessCLI internals may "
f"have changed. Installed lm-eval version: {version('lm_eval')}."
) from e
_add_modelopt_args(run_parser)
args = cli.parse_args()
_inject_modelopt_args_into_model_args(args)
cli.execute(args)
2 changes: 1 addition & 1 deletion examples/llm_eval/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
fire>=0.5.0
lm_eval[api,ifeval]==0.4.8
lm_eval[api,ifeval]>=0.4.10
Comment thread
kevalmorabia97 marked this conversation as resolved.
peft>=0.5.0
rwkv>=0.7.3
torchvision
1 change: 0 additions & 1 deletion examples/puzzletron/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
lm-eval==0.4.8
Comment thread
kevalmorabia97 marked this conversation as resolved.
math-verify
ray
# Likely works for transformers v5 also, but we need to test it
Expand Down
30 changes: 26 additions & 4 deletions tests/examples/llm_eval/test_llm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,38 @@

import subprocess

from _test_utils.examples.models import TINY_LLAMA_PATH
from _test_utils.examples.run_command import run_llm_ptq_command
from _test_utils.examples.run_command import (
extend_cmd_parts,
run_example_command,
run_llm_ptq_command,
)
from _test_utils.torch.misc import minimum_sm
from _test_utils.torch.transformers_models import create_tiny_qwen3_dir


def test_lm_eval_hf(tmp_path):
model_dir = create_tiny_qwen3_dir(tmp_path, with_tokenizer=True)

cmd_parts = extend_cmd_parts(
["python", "lm_eval_hf.py"],
model="hf",
model_args=f"pretrained={model_dir}",
tasks="mmlu",
num_fewshot=5,
limit=0.1,
batch_size=8,
)
run_example_command(cmd_parts, "llm_eval")
Comment thread
kevalmorabia97 marked this conversation as resolved.


@minimum_sm(89)
def test_llama_eval_fp8():
def test_qwen3_eval_fp8(tmp_path):
# Bump max_position_embeddings: TRT-LLM serve rejects prompts longer than
# max_seq_len, and the default (32) is shorter than even simple MMLU prompts.
model_dir = create_tiny_qwen3_dir(tmp_path, with_tokenizer=True, max_position_embeddings=2048)
try:
run_llm_ptq_command(
model=TINY_LLAMA_PATH,
model=str(model_dir),
quant="fp8",
tasks="mmlu,lm_eval,simple_eval",
calib=64,
Expand Down
Loading