Skip to content

Commit aea63d7

Browse files
authored
fix: avoid blocking async functions (#943)
* fix: avoid blocking async functions * test: add more tests for transform flow
1 parent d567c99 commit aea63d7

File tree

2 files changed

+76
-22
lines changed

2 files changed

+76
-22
lines changed

python/cocoindex/flow.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,23 @@ def _transform_helper(
120120
else:
121121
raise ValueError("transform() can only be called on a CocoIndex function")
122122

123-
return _create_data_slice(
124-
flow_builder_state,
125-
lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
123+
def _create_data_slice_inner(
124+
target_scope: _engine.DataScopeRef | None, name: str | None
125+
) -> _engine.DataSlice:
126+
result = flow_builder_state.engine_flow_builder.transform(
126127
kind,
127128
dump_engine_object(spec),
128129
transform_args,
129130
target_scope,
130131
flow_builder_state.field_name_builder.build_name(
131132
name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_"
132133
),
133-
),
134+
)
135+
return result
136+
137+
return _create_data_slice(
138+
flow_builder_state,
139+
_create_data_slice_inner,
134140
name,
135141
)
136142

@@ -166,6 +172,7 @@ def __init__(
166172
def engine_data_slice(self) -> _engine.DataSlice:
167173
"""
168174
Get the internal DataSlice.
175+
This can be blocking.
169176
"""
170177
if self._lazy_lock is None:
171178
if self._data_slice is None:
@@ -179,6 +186,13 @@ def engine_data_slice(self) -> _engine.DataSlice:
179186
self._data_slice = self._data_slice_creator(None)
180187
return self._data_slice
181188

189+
async def engine_data_slice_async(self) -> _engine.DataSlice:
190+
"""
191+
Get the internal DataSlice.
192+
This can be blocking.
193+
"""
194+
return await asyncio.to_thread(lambda: self.engine_data_slice)
195+
182196
def attach_to_scope(self, scope: _engine.DataScopeRef, field_name: str) -> None:
183197
"""
184198
Attach the current data slice (if not yet attached) to the given scope.
@@ -795,9 +809,8 @@ async def setup_async(self, report_to_stdout: bool = False) -> None:
795809
"""
796810
Setup persistent backends of the flow. The async version.
797811
"""
798-
await make_setup_bundle([self]).describe_and_apply_async(
799-
report_to_stdout=report_to_stdout
800-
)
812+
bundle = await make_setup_bundle_async([self])
813+
await bundle.describe_and_apply_async(report_to_stdout=report_to_stdout)
801814

802815
def drop(self, report_to_stdout: bool = False) -> None:
803816
"""
@@ -814,9 +827,8 @@ async def drop_async(self, report_to_stdout: bool = False) -> None:
814827
"""
815828
Drop persistent backends of the flow. The async version.
816829
"""
817-
await make_drop_bundle([self]).describe_and_apply_async(
818-
report_to_stdout=report_to_stdout
819-
)
830+
bundle = await make_drop_bundle_async([self])
831+
await bundle.describe_and_apply_async(report_to_stdout=report_to_stdout)
820832

821833
def close(self) -> None:
822834
"""
@@ -1071,19 +1083,16 @@ async def _build_flow_info_async(self) -> TransformFlowInfo:
10711083
_DataSliceState(flow_builder_state, engine_ds)
10721084
)
10731085

1074-
output = self._flow_fn(**kwargs)
1075-
flow_builder_state.engine_flow_builder.set_direct_output(
1076-
_data_slice_state(output).engine_data_slice
1077-
)
1086+
output = await asyncio.to_thread(lambda: self._flow_fn(**kwargs))
1087+
output_data_slice = await _data_slice_state(output).engine_data_slice_async()
1088+
1089+
flow_builder_state.engine_flow_builder.set_direct_output(output_data_slice)
10781090
engine_flow = (
10791091
await flow_builder_state.engine_flow_builder.build_transient_flow_async(
10801092
execution_context.event_loop
10811093
)
10821094
)
1083-
1084-
engine_return_type = (
1085-
_data_slice_state(output).engine_data_slice.data_type().schema()
1086-
)
1095+
engine_return_type = output_data_slice.data_type().schema()
10871096
python_return_type: type[T] | None = _get_data_slice_annotation_type(
10881097
inspect.signature(self._flow_fn).return_annotation
10891098
)
@@ -1142,28 +1151,42 @@ def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]) -> TransformFlow[T]
11421151
return _transform_flow_wrapper
11431152

11441153

1145-
def make_setup_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
1154+
async def make_setup_bundle_async(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
11461155
"""
11471156
Make a bundle to setup flows with the given names.
11481157
"""
11491158
full_names = []
11501159
for fl in flow_iter:
1151-
fl.internal_flow()
1160+
await fl.internal_flow_async()
11521161
full_names.append(fl.full_name)
11531162
return SetupChangeBundle(_engine.make_setup_bundle(full_names))
11541163

11551164

1156-
def make_drop_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
1165+
def make_setup_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
1166+
"""
1167+
Make a bundle to setup flows with the given names.
1168+
"""
1169+
return execution_context.run(make_setup_bundle_async(flow_iter))
1170+
1171+
1172+
async def make_drop_bundle_async(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
11571173
"""
11581174
Make a bundle to drop flows with the given names.
11591175
"""
11601176
full_names = []
11611177
for fl in flow_iter:
1162-
fl.internal_flow()
1178+
await fl.internal_flow_async()
11631179
full_names.append(fl.full_name)
11641180
return SetupChangeBundle(_engine.make_drop_bundle(full_names))
11651181

11661182

1183+
def make_drop_bundle(flow_iter: Iterable[Flow]) -> SetupChangeBundle:
1184+
"""
1185+
Make a bundle to drop flows with the given names.
1186+
"""
1187+
return execution_context.run(make_drop_bundle_async(flow_iter))
1188+
1189+
11671190
def setup_all_flows(report_to_stdout: bool = False) -> None:
11681191
"""
11691192
Setup all flows registered in the current process.

python/cocoindex/tests/test_transform_flow.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,25 @@ def __call__(self, text: str) -> str:
166166
return f"{text}{self.spec.suffix}"
167167

168168

169+
class GpuAppendSuffixWithAnalyzePrepare(cocoindex.op.FunctionSpec):
170+
suffix: str
171+
172+
173+
@cocoindex.op.executor_class(gpu=True)
174+
class GpuAppendSuffixWithAnalyzePrepareExecutor:
175+
spec: GpuAppendSuffixWithAnalyzePrepare
176+
suffix: str
177+
178+
def analyze(self) -> Any:
179+
return str
180+
181+
def prepare(self) -> None:
182+
self.suffix = self.spec.suffix
183+
184+
def __call__(self, text: str) -> str:
185+
return f"{text}{self.suffix}"
186+
187+
169188
def test_gpu_function() -> None:
170189
@cocoindex.transform_flow()
171190
def transform_flow(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]:
@@ -174,3 +193,15 @@ def transform_flow(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]:
174193
result = transform_flow.eval("Hello")
175194
expected = "Hello world!"
176195
assert result == expected, f"Expected {expected}, got {result}"
196+
197+
@cocoindex.transform_flow()
198+
def transform_flow_with_analyze_prepare(
199+
text: cocoindex.DataSlice[str],
200+
) -> cocoindex.DataSlice[str]:
201+
return text.transform(gpu_append_world).transform(
202+
GpuAppendSuffixWithAnalyzePrepare(suffix="!!")
203+
)
204+
205+
result = transform_flow_with_analyze_prepare.eval("Hello")
206+
expected = "Hello world!!"
207+
assert result == expected, f"Expected {expected}, got {result}"

0 commit comments

Comments
 (0)