Skip to content

Commit 407bbc3

Browse files
committed
Fix custom select statement
xzkostyan#233
1 parent 2de8c5a commit 407bbc3

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,34 @@
1010

1111

1212
class ClickHouseSQLCompiler(compiler.SQLCompiler):
13+
CUSTOM_SELECT_ATTRS = [
14+
'_with_cube', '_with_rollup', '_with_totals', '_final_clause',
15+
'_sample_clause', '_limit_by_clause', '_array_join'
16+
]
17+
18+
def visit_select(
19+
self,
20+
select_stmt,
21+
**kwargs,
22+
):
23+
orig_compile_state_factory = select_stmt._compile_state_factory
24+
25+
def compile_state_factory(self, *args, **kwargs):
26+
result = orig_compile_state_factory(self, *args, **kwargs)
27+
28+
if hasattr(result, 'select_statement'):
29+
# Fix missed attributes
30+
for attr in ClickHouseSQLCompiler.CUSTOM_SELECT_ATTRS:
31+
val = getattr(result.select_statement, attr, None)
32+
33+
if val is not None:
34+
setattr(result.statement, attr, val)
35+
36+
return result
37+
38+
select_stmt._compile_state_factory = compile_state_factory
39+
return super().visit_select(select_stmt=select_stmt, **kwargs)
40+
1341
def visit_mod_binary(self, binary, operator, **kw):
1442
return self.process(binary.left, **kw) + ' %% ' + \
1543
self.process(binary.right, **kw)

clickhouse_sqlalchemy/orm/query.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@ class Query(BaseQuery):
1919
_limit_by = None
2020
_array_join = None
2121

22+
def _statement_20(self, for_statement=False, use_legacy_query_style=True):
23+
orig_smt = super(Query, self)._statement_20(
24+
for_statement=for_statement,
25+
use_legacy_query_style=use_legacy_query_style
26+
)
27+
28+
orig_smt._with_cube = self._with_cube
29+
orig_smt._with_rollup = self._with_rollup
30+
orig_smt._with_totals = self._with_totals
31+
orig_smt._final_clause = self._final
32+
orig_smt._sample_clause = sample_clause(self._sample)
33+
orig_smt._limit_by_clause = self._limit_by
34+
orig_smt._array_join = self._array_join
35+
36+
return orig_smt
37+
2238
def _compile_context(self, *args, **kwargs):
2339
context = super(Query, self)._compile_context(*args, **kwargs)
2440
query = context.query

0 commit comments

Comments
 (0)