Skip to content

Commit 5f6ea05

Browse files
committed
Fix custom select statement
xzkostyan#233
1 parent d332ccb commit 5f6ea05

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,34 @@
88

99

1010
class ClickHouseSQLCompiler(compiler.SQLCompiler):
11+
CUSTOM_SELECT_ATTRS = [
12+
'_with_cube', '_with_rollup', '_with_totals', '_final_clause',
13+
'_sample_clause', '_limit_by_clause', '_array_join'
14+
]
15+
16+
def visit_select(
17+
self,
18+
select_stmt,
19+
**kwargs,
20+
):
21+
orig_compile_state_factory = select_stmt._compile_state_factory
22+
23+
def compile_state_factory(self, *args, **kwargs):
24+
result = orig_compile_state_factory(self, *args, **kwargs)
25+
26+
if hasattr(result, 'select_statement'):
27+
# Fix missed attributes
28+
for attr in ClickHouseSQLCompiler.CUSTOM_SELECT_ATTRS:
29+
val = getattr(result.select_statement, attr, None)
30+
31+
if val is not None:
32+
setattr(result.statement, attr, val)
33+
34+
return result
35+
36+
select_stmt._compile_state_factory = compile_state_factory
37+
return super().visit_select(select_stmt=select_stmt, **kwargs)
38+
1139
def visit_mod_binary(self, binary, operator, **kw):
1240
return self.process(binary.left, **kw) + ' %% ' + \
1341
self.process(binary.right, **kw)

clickhouse_sqlalchemy/orm/query.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@ class Query(BaseQuery):
1919
_limit_by = None
2020
_array_join = None
2121

22+
def _statement_20(self, for_statement=False, use_legacy_query_style=True):
23+
orig_smt = super(Query, self)._statement_20(
24+
for_statement=for_statement,
25+
use_legacy_query_style=use_legacy_query_style
26+
)
27+
28+
orig_smt._with_cube = self._with_cube
29+
orig_smt._with_rollup = self._with_rollup
30+
orig_smt._with_totals = self._with_totals
31+
orig_smt._final_clause = self._final
32+
orig_smt._sample_clause = sample_clause(self._sample)
33+
orig_smt._limit_by_clause = self._limit_by
34+
orig_smt._array_join = self._array_join
35+
36+
return orig_smt
37+
2238
def _compile_context(self, *args, **kwargs):
2339
context = super(Query, self)._compile_context(*args, **kwargs)
2440
query = context.query

0 commit comments

Comments
 (0)