Commit fa2e173

14b conversion

1 parent 262ce19 commit fa2e173

File tree: 2 files changed, +60 -2 lines


scripts/convert_cosmos_to_diffusers.py

Lines changed: 59 additions & 1 deletion
````diff
@@ -29,13 +29,52 @@
 
 Convert checkpoint
 ```bash
+# pre-trained
 transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
 
 python scripts/convert_cosmos_to_diffusers.py \
   --transformer_type Cosmos-2.5-Predict-Base-2B \
   --transformer_ckpt_path $transformer_ckpt_path \
   --vae_type wan2.1 \
-  --output_path converted/cosmos-p2.5-base-2b \
+  --output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
+  --save_pipeline
+
+# post-trained
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+  --transformer_type Cosmos-2.5-Predict-Base-2B \
+  --transformer_ckpt_path $transformer_ckpt_path \
+  --vae_type wan2.1 \
+  --output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
+  --save_pipeline
+```
+
+## 14B
+
+```bash
+hf download nvidia/Cosmos-Predict2.5-14B
+```
+
+```bash
+# pre-trained
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/03eb354f35eae0d6e0c1be3c9f94d8551e125570/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+  --transformer_type Cosmos-2.5-Predict-Base-14B \
+  --transformer_ckpt_path $transformer_ckpt_path \
+  --vae_type wan2.1 \
+  --output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
+  --save_pipeline
+
+# post-trained
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/03eb354f35eae0d6e0c1be3c9f94d8551e125570/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+  --transformer_type Cosmos-2.5-Predict-Base-14B \
+  --transformer_ckpt_path $transformer_ckpt_path \
+  --vae_type wan2.1 \
+  --output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
   --save_pipeline
 ```
````
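To sanity-check a conversion, a directory written with `--save_pipeline` should load like any other local diffusers pipeline folder. A minimal sketch, assuming the 2B pre-trained output path from the commands above exists locally; the dtype, device, and inspection step are assumptions, not part of this commit:

```python
import torch
from diffusers import DiffusionPipeline

# Path produced by the first conversion command above (assumed to exist locally).
pipe = DiffusionPipeline.from_pretrained(
    "converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c",
    torch_dtype=torch.bfloat16,  # the source checkpoints are *_ema_bf16.pt, so bf16 is a reasonable choice
)
pipe.to("cuda")

# Inspect the converted transformer's config to confirm the expected architecture was written.
print(pipe.transformer.config)
```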
```diff
@@ -298,6 +337,25 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
         "crossattn_proj_in_channels": 100352,
         "encoder_hidden_states_channels": 1024,
     },
+    "Cosmos-2.5-Predict-Base-14B": {
+        "in_channels": 16 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 40,
+        "attention_head_dim": 128,
+        "num_layers": 36,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 3.0, 3.0),
+        "concat_padding_mask": True,
+        # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
+        "extra_pos_embed_type": None,
+        "use_crossattn_projection": True,
+        "crossattn_proj_in_channels": 100352,
+        "encoder_hidden_states_channels": 1024,
+    },
 }
 
 VAE_KEYS_RENAME_DICT = {
```
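For reference, the new 14B entry implies an inner width of num_attention_heads × attention_head_dim = 40 × 128 = 5120, and its `in_channels` of 16 + 1 is the 16 latent channels plus the padding-mask channel implied by `concat_padding_mask`. A small arithmetic sketch; the derived numbers are illustration only, not additional configuration:

```python
# Values copied from the "Cosmos-2.5-Predict-Base-14B" entry in the diff above.
num_attention_heads = 40
attention_head_dim = 128
hidden_size = num_attention_heads * attention_head_dim
print(hidden_size)                   # 5120: the transformer's inner width

mlp_ratio = 4.0
print(int(mlp_ratio * hidden_size))  # 20480: feed-forward width implied by mlp_ratio

in_channels = 16 + 1                 # 16 latent channels + 1 concatenated padding-mask channel
```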

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -133,7 +133,7 @@ def retrieve_latents(
 ...     num_frames=93,
 ...     generator=torch.Generator().manual_seed(1),
 ... ).frames[0]
->>> # export_to_video(video, "image2world.mp4", fps=16)
+>>> export_to_video(video, "image2world.mp4", fps=16)
 
 >>> # Video2World: condition on an input clip and predict a 93-frame world video.
 >>> prompt = (
```
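The line un-commented in the doc example uses `export_to_video` from `diffusers.utils`. A minimal sketch of that final step, assuming `video` already holds the frames returned by the pipeline call shown in the docstring:

```python
from diffusers.utils import export_to_video

# `video` is the list of frames from the pipeline call in the example above (assumed in scope).
export_to_video(video, "image2world.mp4", fps=16)
```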
