-
Notifications
You must be signed in to change notification settings - Fork 3
Context Parallelism #67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
From if parallel_dims.cp_enabled: # the following is necessary for CP w/ flex attention
from torch.distributed.tensor.experimental._attention import _set_cp_global_var, _DispatchMode, _cp_options
# set_rotate_method("alltoall") # alltoall or allgather (only allgather for flex)
_set_cp_global_var("cp_shard_dim", 2)
# _cp_options.enable_load_balance = True # no load balancing for flex
torch.distributed.tensor.experimental._attention._dispatch_mode = (
_DispatchMode.TORCH_FUNCTION
)
|
|
Problems with with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(1, 16, 4096, 192), dtype=torch.bfloat16,
grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(1, 16, 4096, 192), dtype=torch.bfloat16,
grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(1, 16, 4096, 128), dtype=torch.bfloat16,
grad_fn=<TransposeBackward0>)), **{'is_causal': True, 'scale': 0.07216878364870322}): got RuntimeError('No available kernel. Aborting execution.') |
|
@rlrs Context parallelism now runs for gemma and llama |
New dcp script related to model where yarn has been used to extend the context length
|
I have now also included the related to YaRN in this PR, see d36078d
|
|
All the CMDs below run as expected. Anything else we should test before merging this? @rlrs Training Comparing a Maester DCP checkpoint against a Hugging Face model python compare_models.py \
--job-config jobs/munin-32k/config.json \
--checkpoint-dir jobs/munin-32k/checkpoints/step-1000 \
--hf-model oliverkinch/munin-32k-step-1000 \
--num-prompts 0 \
--dataset data/wiki-expanded-hf \
--dataset-samples 4 \
--dataset-max-length 512Giving output as: PROMPT 0: '# Frankrig\n\nFrankrig (fransk: "France"), officielt Den Franske Republik (fransk:'
Tokenized length: 512
Logit max abs diff: 7.428315e+00
Logit mean abs diff: 1.034052e-01
HF loss: 0.775544 (ppl=2.172)
Maester loss: 0.771280 (ppl=2.163)Are these differences acceptable? YaRN convert python -u scripts/convert/llama/from_dcp_yarn.py \
jobs/munin-32k/checkpoints \
/tmp/munin-open-7b-pt-export \
--name step-1000 \
--base danish-foundation-models/munin-open-7b-pt |

Implements CP for non MoE models. Implementing CP for MoEs will be in a separate PR.
Fix #31.
#38 will be redundant given this PR.