-
Notifications
You must be signed in to change notification settings - Fork 1
/
app.py
56 lines (44 loc) · 1.69 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json
import logging
import os

import pyarrow as pa
import pyarrow.parquet as pq
from datafusion import SessionContext
from pyarrow import flight
class ExampleFlightServer(flight.FlightServerBase):
    """Arrow Flight server backed by Parquet files in the working directory.

    * ``do_put`` stores uploaded record batches in ``<table>.parquet``.
    * ``do_get`` registers ``<table>.parquet`` with a fresh DataFusion
      ``SessionContext`` and streams back the result of a SQL query.
    * ``do_action`` echoes the action type back as a JSON result.

    Handler errors are logged server-side and re-raised so the client receives
    an error instead of an empty or missing response.
    """

    _log = logging.getLogger(__name__)

    @staticmethod
    def _parquet_path(table_name):
        """Return ``<table_name>.parquet``, rejecting names that are not a plain file name.

        Table names come from clients and are used directly as file names, so
        empty names and names containing a path separator (``../x``, ``/x``)
        are refused to keep reads and writes inside the working directory.

        Raises:
            ValueError: if *table_name* is not a non-empty plain file name.
        """
        if (
            not isinstance(table_name, str)
            or not table_name
            or os.path.basename(table_name) != table_name
        ):
            raise ValueError(f"Invalid table name: {table_name!r}")
        return f"{table_name}.parquet"

    @staticmethod
    def _parse_ticket(raw):
        """Decode ticket bytes into ``(sql_query, table_name)``.

        Raises:
            ValueError: if *raw* is not UTF-8 JSON object text containing
                ``sql`` and ``table`` keys.
        """
        try:
            request = json.loads(raw.decode())
            return request["sql"], request["table"]
        except (ValueError, KeyError, TypeError) as exc:
            # UnicodeDecodeError and json.JSONDecodeError are ValueErrors;
            # KeyError / TypeError cover missing keys or a non-object payload.
            raise ValueError(
                'Ticket must be JSON of the form {"sql": "...", "table": "..."}'
            ) from exc

    def do_get(self, context, ticket):
        """Run the SQL query described by *ticket* and stream the result back.

        The ticket payload must be UTF-8 JSON ``{"sql": ..., "table": ...}``;
        ``<table>.parquet`` is registered under the name ``<table>`` before the
        query runs.

        Raises:
            ValueError: if the ticket is malformed or names an unsafe table.
            Errors from registering the file or running the query are logged
            and re-raised so the client sees them.
        """
        sql_query, table_name = self._parse_ticket(ticket.ticket)
        file_path = self._parquet_path(table_name)
        # NOTE(review): the SQL text is client-supplied and executed as-is;
        # DataFusion may also accept DDL/DML statements here (e.g. CREATE
        # EXTERNAL TABLE, COPY ... TO). If clients are untrusted, consider
        # restricting statements via DataFusion's SQL options — confirm the
        # API available in the installed datafusion version.
        try:
            ctx = SessionContext()
            ctx.register_parquet(table_name, file_path)
            table = ctx.sql(sql_query).to_arrow_table()
        except Exception:
            self._log.exception("do_get failed for table %r", table_name)
            raise
        return flight.RecordBatchStream(table)

    def do_put(self, context, descriptor, reader, writer):
        """Persist an uploaded stream to ``<table>.parquet``.

        The table name is ``descriptor.path[0]`` decoded as UTF-8. When the
        file already exists, the uploaded rows are placed before the stored
        rows (``pa.concat_tables([new, existing])``) and the file is rewritten.

        Raises:
            ValueError: if the descriptor has no path or names an unsafe table.
            Errors while reading the existing file, concatenating (e.g. a
            schema mismatch) or writing are logged and re-raised; the
            previously stored file is not overwritten in that case.
        """
        if not descriptor.path:
            raise ValueError("do_put requires a path descriptor naming the table")
        table_name = descriptor.path[0].decode('utf-8')
        file_path = self._parquet_path(table_name)
        data_table = reader.read_all()
        # Write to a side file and rename it over the target, so a failure
        # mid-write cannot leave a truncated copy of the stored table.
        tmp_path = f"{file_path}.tmp"
        # NOTE(review): this read-modify-write is not synchronized; if the
        # server handles puts to the same table in parallel, rows could be
        # lost — consider a per-table lock if that matters.
        try:
            if os.path.exists(file_path):
                existing_table = pq.read_table(file_path)
                data_table = pa.concat_tables([data_table, existing_table])
            pq.write_table(data_table, tmp_path)
            os.replace(tmp_path, file_path)
        except Exception:
            self._log.exception("do_put failed for table %r", table_name)
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best effort: the temp file may never have been created
            raise

    def do_action(self, context, action):
        """Answer any action with one result whose body is ``{"action": <type>}`` JSON."""
        payload = json.dumps({"action": action.type}).encode()
        return [flight.Result(payload)]
def run_flight_server(host="localhost", port=8081):
    """Create an ExampleFlightServer bound to ``host:port`` and run it via ``serve()``.

    Args:
        host: Interface to listen on; defaults to the original ``"localhost"``.
        port: TCP port to listen on; defaults to the original ``8081``.
    """
    location = flight.Location.for_grpc_tcp(host, port)
    server = ExampleFlightServer(location)
    # With the defaults this prints exactly the original startup message.
    print(f"Starting Flight server on {host}:{port}")
    server.serve()
if __name__ == "__main__":
    # Start the server only when this file is executed as a script;
    # importing the module defines the server without running it.
    run_flight_server()