Skip to content

Commit d839b94

Browse files
committed
adapted to ms2.7
1 parent f9a945c commit d839b94

File tree

3 files changed

+5
-6
lines changed

3 files changed

+5
-6
lines changed

examples/opensora_hpcai/opensora/models/layers/operation_selector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import mindspore as ms
22
from mindspore import mint, ops
3-
from mindspore.ops.function.array_func import chunk_ext, repeat_interleave_ext
3+
from mindspore.ops.function.array_func import repeat_interleave_ext
44

55
use_dynamic_ops = False
66

@@ -66,7 +66,7 @@ def get_chunk_op():
6666
if (mode == 0) and (not check_dynamic_mode()):
6767
return ops.chunk
6868
else:
69-
return chunk_ext
69+
return mint.chunk
7070

7171

7272
def get_split_op():

examples/opensora_hpcai/opensora/models/layers/rotary_embedding.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@
1212
import numpy as np
1313

1414
import mindspore as ms
15-
from mindspore import Parameter, Tensor, dtype, nn, ops
16-
from mindspore.ops.function.array_func import chunk_ext as chunk
15+
from mindspore import Parameter, Tensor, dtype, mint, nn, ops
1716

1817
from .operation_selector import get_repeat_interleave_op
1918

2019

2120
def rotate_half(x: Tensor) -> Tensor:
2221
x = x.reshape(x.shape[:-1] + (-1, 2)) # ... (d r) -> ... d r, r = 2
23-
x1, x2 = chunk(x, 2, -1)
22+
x1, x2 = mint.chunk(x, 2, -1)
2423
x = ops.concat((-x2, x1), axis=-1)
2524
return x.reshape(x.shape[:-2] + (-1,)) # '... d r -> ... (d r)'
2625

examples/opensora_hpcai/scripts/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def main(args):
478478
videos.append(to_numpy(samples)[:, args.condition_frame_length if loop_i > 0 else 0 :])
479479
batch_time = time.time() - start_time
480480
logger.info(
481-
f"Batch time cost: {batch_time:.3f}s, sampling speed: {args.sampling_steps * ns / batch_time:.4f} step/s"
481+
f"Batch time cost: {batch_time:.3f}s, sampling speed: {batch_time / args.sampling_steps * ns:.4f} s/step"
482482
)
483483

484484
latents = np.concatenate(latents, axis=2)

0 commit comments

Comments
 (0)