pystarburst analytics examples (querying aviation data part deux)

I had so much fun publishing my pystarburst (the dataframe api) post and running that code in Starburst Galaxy that I wanted to share some more examples. Can you tell I’m pretty fired-up?!?! This time, I’m using the datasets and analytical questions from my earlier querying aviation data in the cloud (leveraging starburst galaxy) post.

I encourage you to at least take a quick look through that last post, but I’ll provide a brief introduction to the datasets in this post. To set up tables for yourself so that you can run the PyStarburst example code, you’ll definitely want to follow the steps presented there.

More about DataFrames

In my first PyStarburst post I explicitly stated I was NOT “attempting to teach you everything you need to know about the DataFrame API” (and I’m still NOT trying to do that), but I’m realizing I should be a tiny bit nicer than that. Here’s a definition from the Apache Spark documentation to get you started.

A DataFrame is a Dataset organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood.

https://spark.apache.org/docs/latest/sql-programming-guide.html

The PyStarburst DataFrame API itself is the list of objects and methods available to Python programmers to work with this VERY INTERESTING collection of rows. What makes it so interesting? Well… the collections are NOT collections of data. They are instructions for how to build the DataFrames (in a highly parallelized cluster) when, and only when, they are needed.

Yep, that makes about ZERO SENSE the first time you hear it. This is part of the “lazy execution” phrase you may have heard about regarding DataFrames. Basically, PyStarburst (and PySpark for that matter, whose API this implementation is based on) lets you write as much code as you want with the API (and create as many DataFrame objects as you need), but… no real data is read or written until an operation/function is called that requires an ACTION to be performed. Generally that means only when you want to retrieve or persist results.

Yes, it is STILL A LOT, but that’s all I’m going to tell you now. I just want you to TRUST ME that in the code below when you see an object that is a DataFrame, do NOT assume that any real work was done to fetch that data and bring it back to the Python program. I’ll mostly be using the show() function when I need to see some results. That is one of those “action” operations I mentioned above (as is calling count() on a whole DataFrame, like in Q1). Most of the other functions are called “transformations” and just create additional DataFrames (again, that means they don’t really do any heavy lifting).
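If the lazy part still feels abstract, here’s a toy, pure-Python sketch of the idea. LazyFrame is a hypothetical class I made up for illustration; it is NOT part of PyStarburst and says nothing about how PyStarburst is actually implemented. Its transformations (filter() and select()) only record a step, and the pretend table isn’t read until an action (collect() or count()) runs.

```python
from typing import Any, Callable, Iterable, List, Optional, Tuple


class LazyFrame:
    """Hypothetical toy frame for teaching lazy execution (NOT the PyStarburst API)."""

    def __init__(self, source: Callable[[], Iterable[dict]],
                 steps: Optional[List[Tuple[str, Any]]] = None):
        self._source = source        # callable that fetches the raw rows
        self._steps = steps or []    # recorded transformations, in order

    # transformations: record a step and return a NEW frame; no data is touched
    def filter(self, predicate: Callable[[dict], bool]) -> "LazyFrame":
        return LazyFrame(self._source, self._steps + [("filter", predicate)])

    def select(self, *cols: str) -> "LazyFrame":
        return LazyFrame(self._source, self._steps + [("select", cols)])

    # actions: read the source, then replay every recorded step
    def collect(self) -> List[dict]:
        rows = list(self._source())
        for kind, arg in self._steps:
            if kind == "filter":
                rows = [r for r in rows if arg(r)]
            else:  # "select"
                rows = [{c: r[c] for c in arg} for r in rows]
        return rows

    def count(self) -> int:
        return len(self.collect())


reads = []  # one entry per time the pretend table is actually read


def read_flights() -> List[dict]:
    reads.append("read")
    # made-up sample rows, for illustration only
    return [
        {"orig": "AAA", "dest": "BBB", "distance": 2100},
        {"orig": "AAA", "dest": "CCC", "distance": 300},
        {"orig": "DDD", "dest": "BBB", "distance": 1800},
    ]


long_hauls = (LazyFrame(read_flights)
              .filter(lambda r: r["distance"] > 1500)
              .select("orig", "dest"))
print(len(reads))          # 0: building the pipeline read nothing
print(long_hauls.count())  # 2: the action triggers the read
print(len(reads))          # 1
```

In this toy, every action replays the whole pipeline from the source, so calling count() a second time reads the “table” a second time.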

If you don’t feel TOTALLY LOST yet, feel free to check out my functional programming and big data APIs posts for more on this whole topic. 🙂 In fairness, it was CLEAR AS MUD for me when I first started working with Spark.

Create & load the tables

All the details for getting Starburst Galaxy set up with the tables you need are documented in that earlier post. The following ERD should give you a rough idea of the tables and their logical relationships for this aviation-oriented domain.

Analyze the data

I’m going to port the 7 analytical questions (and their SQL solutions) from that post. This time, I’ll answer them with the DataFrame API. Let’s jump in!

Like before, I’m just running this from the CLI. I’m saving my code in a file called aviation.py and then running it by entering python3 aviation.py each time I want to run my code.

Here’s the boilerplate code again, including all the imports you’ll need for the code to come.

import trino
from pystarburst import Session
from pystarburst import functions as f
from pystarburst.functions import col, lag, round, row_number
from pystarburst.window import Window

db_parameters = {
    "host": "lXXXXXXXXXXXXr.trino.galaxy.starburst.io",
    "port": 443,
    "http_scheme": "https",
    "auth": trino.auth.BasicAuthentication("lXXXXXX/XXXXXXn", "<password>")
}
session = Session.builder.configs(db_parameters).create()

Q1: How many rows in the flight table?

SQL solution

SELECT count(*) 
  FROM mycloud.aviation.raw_flight;

Python solution

This is super simple. Just retrieve the raw_flight table as a DataFrame and then call the count() function (which returns an integer).

allFs = session.table("mycloud.aviation.raw_flight")
print(allFs.count())

Results

2056494

Q2: What country are most of the airports located in?

SQL solution

SELECT country, count(*) AS num_airports
  FROM mycloud.aviation.raw_airport
 GROUP BY country
 ORDER BY num_airports DESC;

Python solution

This one is pretty straightforward, too. After getting hold of the raw_airport table, I call group_by() and then count() on the grouped rows. Note that this count() is called on grouped data, so unlike Q1 it doesn’t return an integer; it returns another DataFrame with a new count column. Finally, I sort the results by that count, descending, and show only the first row.

# get the whole table, aggregate & sort
mostAs = session \
    .table("mycloud.aviation.raw_airport") \
	.group_by("country").count() \
	.sort("count", ascending=False)
mostAs.show(1)

Results

-----------------------
|"country"  |"count"  |
-----------------------
|USA        |3363     |
-----------------------
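If it helps to see what group_by("country").count().sort(...) computes, here’s a plain-Python sketch using collections.Counter on a handful of made-up rows (NOT the real raw_airport data). Unlike the DataFrame version, this does all the work eagerly inside the Python process instead of in the cluster.

```python
from collections import Counter

# made-up stand-in rows for raw_airport; only the "country" column matters here
airports = [
    {"iata": "AAA", "country": "USA"},
    {"iata": "BBB", "country": "USA"},
    {"iata": "CCC", "country": "Canada"},
    {"iata": "DDD", "country": "USA"},
    {"iata": "EEE", "country": "Mexico"},
    {"iata": "FFF", "country": "Canada"},
]

# group_by("country").count(): one (country, count) pair per distinct country
counts = Counter(row["country"] for row in airports)

# sort("count", ascending=False), then keep just the first row like show(1)
by_count = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
print(by_count[0])  # ('USA', 3)
```

Counter.most_common(5) would give you a Q3-style top 5 in a single call.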

Q3: What are the top 5 airline codes with the most number of flights?

SQL solution

SELECT unique_carrier, count(*) as num_flights
  FROM mycloud.aviation.raw_flight
 GROUP BY unique_carrier 
 ORDER BY num_flights DESC
 LIMIT 5;

Python solution

This is VERY SIMILAR to Q2; the only new piece is rename(), which shortens the unique_carrier column name to carr.

# get the whole table, aggregate & sort
mostFs = session \
    .table("mycloud.aviation.raw_flight") \
	.group_by("unique_carrier").count() \
	.rename("unique_carrier", "carr") \
	.sort("count", ascending=False)
mostFs.show(5)

Results

--------------------
|"carr"  |"count"  |
--------------------
|WN      |356167   |
|AA      |175969   |
|OO      |166445   |
|MQ      |141178   |
|US      |133403   |
--------------------

Q4: Same question, but show the airline carrier’s name.

SQL solution

SELECT c.description, count(*) as num_flights
  FROM mycloud.aviation.raw_flight  f 
  JOIN mycloud.aviation.raw_carrier c
    ON (f.unique_carrier = c.code)
 GROUP BY c.description 
 ORDER BY num_flights DESC
 LIMIT 5;

Python solution

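You can create a DataFrame for the raw_carrier table to join on later. Then, just pick up where you left off in Q3 by chaining a few more methods onto mostFs: join() (matching carr to code), drop() to remove the now-redundant code column, and sort(). The second argument to show() caps each column’s display width at 30 characters, which is why the last description is truncated.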

# get all of the carriers
allCs = session.table("mycloud.aviation.raw_carrier")

# repurpose mostFs from above (or chain on it) 
#   to join the 2 DFs and sort the results that
#   have already been grouped
top5CarrNm = mostFs \
    .join(allCs, mostFs.carr == allCs.code) \
    .drop("code") \
	.sort("count", ascending=False)
top5CarrNm.show(5, 30)

Results

-----------------------------------------------------
|"carr"  |"count"  |"description"                   |
-----------------------------------------------------
|WN      |356167   |Southwest Airlines Co.          |
|AA      |175969   |American Airlines Inc.          |
|OO      |166445   |Skywest Airlines Inc.           |
|MQ      |141178   |American Eagle Airlines Inc.    |
|US      |133403   |US Airways Inc. (Merged wit...  |
-----------------------------------------------------
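Here’s a rough plain-Python picture of that inner join, using the top four carrier counts and descriptions from the results above (the fifth description is truncated, so I left it out). The "ZZ" carrier is a made-up code with no flights, included only to show that an inner join drops rows that have no match. A distributed join makes no promise about row order, which is why the DataFrame code sorts again after the join; the sketch does the same.

```python
# (carr, count) pairs copied from the Q3 results above (top 4 only)
most_fs = [("WN", 356167), ("AA", 175969), ("OO", 166445), ("MQ", 141178)]

# code -> description lookup standing in for raw_carrier;
# "ZZ" is a made-up code with no matching flights
all_cs = {
    "AA": "American Airlines Inc.",
    "MQ": "American Eagle Airlines Inc.",
    "OO": "Skywest Airlines Inc.",
    "WN": "Southwest Airlines Co.",
    "ZZ": "Hypothetical Air",
}

# inner join on carr == code; the code column is never carried over (like drop("code"))
joined = [
    {"carr": carr, "count": cnt, "description": all_cs[carr]}
    for carr, cnt in most_fs
    if carr in all_cs
]

# sort("count", ascending=False) after the join
joined.sort(key=lambda r: r["count"], reverse=True)
for r in joined:
    print(r["carr"], r["count"], r["description"])
```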

Q5: What are the most common airplane models for flights over 1500 miles?

SQL solution

SELECT p.model, count(*) as num_flights
  FROM mycloud.aviation.raw_flight f 
  JOIN mycloud.aviation.raw_plane  p
    ON (f.tail_number = p.tail_number)
 WHERE f.distance > 1500
   AND p.model IS NOT NULL
 GROUP BY p.model
 ORDER BY num_flights desc
 LIMIT 10;

Python solution

# trimFs are flights projected & filtered
trimFs = session.table("mycloud.aviation.raw_flight") \
	.rename("tail_number", "tNbr") \
	.select("tNbr", "distance") \
	.filter(col("distance") > 1500) 

# trimPs are planes table projected & filtered
trimPs = session.table("mycloud.aviation.raw_plane") \
	.select("tail_number", "model") \
	.filter("model is not null")

# join, group & sort
q5Answer = trimFs \
	.join(trimPs, trimFs.tNbr == trimPs.tail_number) \
	.drop("tail_number") \
	.group_by("model").count() \
	.sort("count", ascending=False)	
q5Answer.show()

Results

----------------------
|"model"   |"count"  |
----------------------
|A320-232  |28926    |
|737-7H4   |21597    |
|757-222   |14609    |
|757-232   |12972    |
|737-824   |10789    |
|737-832   |9393     |
|A319-131  |5881     |
|A321-211  |4921     |
|767-332   |4522     |
|A319-132  |4480     |
----------------------
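The same project, filter, join and group steps can be mirrored in plain Python. This is only a sketch on made-up rows (the tail numbers and model names are invented), and it assumes each tail_number appears just once in the planes data, so a dict can serve as the lookup side of the join.

```python
from collections import Counter

# made-up stand-ins for raw_flight and raw_plane rows (illustration only)
flights = [
    {"tail_number": "N001", "distance": 2000},
    {"tail_number": "N001", "distance": 400},
    {"tail_number": "N002", "distance": 1600},
    {"tail_number": "N003", "distance": 2500},
    {"tail_number": "N004", "distance": 1700},
]
planes = [
    {"tail_number": "N001", "model": "MODEL-X"},
    {"tail_number": "N002", "model": "MODEL-X"},
    {"tail_number": "N003", "model": None},
    {"tail_number": "N004", "model": "MODEL-Y"},
]

# trimFs: rename tail_number -> tNbr, keep two columns, keep flights over 1500 miles
trim_fs = [{"tNbr": flt["tail_number"], "distance": flt["distance"]}
           for flt in flights if flt["distance"] > 1500]

# trimPs: keep two columns and drop planes with no model (assumes unique tail_number)
trim_ps = {p["tail_number"]: p["model"] for p in planes if p["model"] is not None}

# inner join on tNbr == tail_number, then group_by("model").count()
q5_counts = Counter(trim_ps[flt["tNbr"]] for flt in trim_fs if flt["tNbr"] in trim_ps)

# sort("count", ascending=False)
print(q5_counts.most_common())  # [('MODEL-X', 2), ('MODEL-Y', 1)]
```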

Q6: What is the month over month percentage change of number of flights departing from each airport?

SQL solution

This solution leverages Common Table Expressions (CTEs), which you can think of as temporary tables. I’ll follow this general approach in the Python solution, where I explain the code a bit more.

WITH agg_flights AS (
SELECT origination, month, 
       COUNT(*) AS num_flights
  FROM mycloud.aviation.raw_flight
 GROUP BY 1,2
),
 
change_flights AS (
SELECT origination, month, num_flights,
       LAG(num_flights, 1)
         OVER(PARTITION BY origination
                ORDER BY month ASC) 
           AS num_flights_before
  FROM agg_flights
)
 
SELECT origination, month, num_flights, num_flights_before,
       ROUND((1.0 * (num_flights - num_flights_before)) / 
             (1.0 * (num_flights_before)), 2)
          AS perc_change
  FROM change_flights;

Python solution

This first bit emulates the creation of the agg_flights CTE above.

# temp DF holds counts for each originating airport 
#   by month
aggFlights = session.table("mycloud.aviation.raw_flight") \
	.select("origination", "month") \
	.rename("origination", "orig") \
	.group_by("orig", "month").count() \
	.rename("count", "num_fs")

Then I created a Window definition that partitions the rows by originating airport and orders each partition by month. With it, lag() adds a new column holding the number of flights from the previous row (that is, the prior month) for the same airport.

# define a window specification
w1 = Window.partition_by("orig").order_by("month")

# add col to grab the prior row's nbr flights
changeFlights = aggFlights \
	.withColumn("num_fs_b4", \
		lag("num_fs",1).over(w1))

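Lastly, I calculated the percentage change in the number of flights from the prior month. Like the SQL, it comes out as a fraction (0.1 is roughly a 10% increase), but I rounded to 1 decimal place instead of the 2 used in the SQL, which is why small drops show up as -0.0 in the results.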

# add col for the percentage change
q6Answer = changeFlights \
	.withColumn("perc_chg", \
		round((1.0 * (col("num_fs") - col("num_fs_b4")) / \
		      (1.0 * col("num_fs_b4"))), 1))
q6Answer.show()

Results

----------------------------------------------------------
|"orig"  |"month"  |"num_fs"  |"num_fs_b4"  |"perc_chg"  |
----------------------------------------------------------
|ABE     |1        |99        |NULL         |NULL        |
|ABE     |2        |111       |99           |0.1         |
|ABE     |3        |127       |111          |0.1         |
|ABE     |4        |142       |127          |0.1         |
|ABE     |5        |137       |142          |-0.0        |
|ABE     |6        |116       |137          |-0.2        |
|ABE     |7        |113       |116          |-0.0        |
|ABE     |8        |106       |113          |-0.1        |
|ABE     |9        |94        |106          |-0.1        |
|ABE     |10       |140       |94           |0.5         |
----------------------------------------------------------
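To see exactly what lag("num_fs", 1).over(w1) and the perc_chg column compute, here’s a plain-Python sketch that walks each orig partition in month order. The ABE counts come from the results above; the XYZ rows are made up (and deliberately out of order) just to show that the rows get sorted first and that the lag resets to NULL (None here) at the start of each partition.

```python
from itertools import groupby

# (orig, month, num_fs) rows: ABE copied from the Q6 results, XYZ made up
agg_flights = [
    ("XYZ", 2, 60), ("XYZ", 1, 50),
    ("ABE", 1, 99), ("ABE", 2, 111), ("ABE", 3, 127), ("ABE", 4, 142),
    ("ABE", 5, 137), ("ABE", 6, 116), ("ABE", 7, 113), ("ABE", 8, 106),
    ("ABE", 9, 94), ("ABE", 10, 140),
]


def month_over_month(rows):
    """Mimic lag(num_fs, 1) over (partition by orig order by month), then perc_chg."""
    out = []
    # partition_by("orig") + order_by("month"): groupby needs rows sorted by key first
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))
    for orig, part in groupby(ordered, key=lambda r: r[0]):
        prev = None  # first row of each partition has no prior row, like NULL
        for _, month, num_fs in part:
            perc = None if prev is None else round((num_fs - prev) / prev, 1)
            out.append((orig, month, num_fs, prev, perc))
            prev = num_fs
    return out


for row in month_over_month(agg_flights):
    print(row)
```

Python 3’s / is already true division, so the 1.0 * factors from the SQL and DataFrame versions aren’t needed here; in Trino, dividing one integer by another truncates, so those factors keep the result fractional.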

Q7: Determine the top 3 routes departing from each airport.

SQL solution

This is another CTE solution and, as in Q6, I’ll follow the same approach in the Python solution.

WITH popular_routes AS (
SELECT origination, destination,
       COUNT(*) AS num_flights
  FROM mycloud.aviation.raw_flight
 GROUP BY 1, 2
),
 
ranked_routes AS (
SELECT origination, destination,
       ROW_NUMBER() 
         OVER(PARTITION BY origination 
               ORDER BY num_flights DESC) 
           AS rank
  FROM popular_routes
)
 
SELECT origination, destination, rank
  FROM ranked_routes
  WHERE rank <= 3
  ORDER BY origination, rank;

Python solution

This first bit emulates the creation of the popular_routes CTE above.

# determine counts from orig>dest pairs
popularRoutes = session \
	.table("mycloud.aviation.raw_flight") \
	.rename("origination", "orig") \
	.rename("destination", "dest") \
	.group_by("orig", "dest").count() \
	.rename("count", "num_fs")

Then I created a Window definition that partitions the routes by originating airport and orders each partition by number of flights, highest first, so that row_number() can assign a ranking to each destination.

# define a window specification
w2 = Window.partition_by("orig") \
	.order_by(col("num_fs").desc())

# add col to put the curr row's ranking in
rankedRoutes = popularRoutes \
	.withColumn("rank", \
		row_number().over(w2))

Lastly, I just tossed out any ranking greater than 3 and sorted to show the top values for each originating airport.

# just show up to 3 for each orig airport
q7Answer = rankedRoutes \
	.filter(col("rank") <= 3) \
	.sort("orig", "rank")
q7Answer.show(17)

Results

---------------------------------------
|"orig"  |"dest"  |"num_fs"  |"rank"  |
---------------------------------------
|ABE     |ORD     |420       |1       |
|ABE     |DTW     |282       |2       |
|ABE     |ATL     |247       |3       |
|ABI     |DFW     |773       |1       |
|ABQ     |PHX     |1619      |1       |
|ABQ     |DEN     |1254      |2       |
|ABQ     |DAL     |951       |3       |
|ABY     |ATL     |338       |1       |
|ACK     |EWR     |62        |1       |
|ACK     |JFK     |58        |2       |
|ACT     |DFW     |567       |1       |
|ACV     |SFO     |705       |1       |
|ACV     |SMF     |175       |2       |
|ACV     |SLC     |134       |3       |
|ACY     |ATL     |34        |1       |
|ACY     |LGA     |1         |2       |
|ACY     |JFK     |1         |3       |
---------------------------------------
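Here’s a plain-Python sketch of what row_number().over(w2) plus the rank filter does, using a few rows from the results above and keeping the top 2 (instead of 3) so the filter actually removes something. It also highlights one detail: row_number() always hands out distinct numbers, so tied routes like ACY’s LGA and JFK (1 flight each) get ranks 2 and 3 in no guaranteed order. In this sketch the tie is broken by input order, because Python’s sorted() is stable.

```python
from itertools import groupby

# (orig, dest, num_fs) rows copied from the Q7 results above
popular_routes = [
    ("ABE", "ORD", 420), ("ABE", "DTW", 282), ("ABE", "ATL", 247),
    ("ABQ", "PHX", 1619), ("ABQ", "DEN", 1254), ("ABQ", "DAL", 951),
    ("ACY", "ATL", 34), ("ACY", "LGA", 1), ("ACY", "JFK", 1),
]


def top_n_routes(rows, n):
    """Mimic row_number() over (partition by orig order by num_fs desc), keep rank <= n."""
    out = []
    # sort by partition key, then by num_fs descending within each partition
    ordered = sorted(rows, key=lambda r: (r[0], -r[2]))
    for orig, part in groupby(ordered, key=lambda r: r[0]):
        # enumerate() plays the role of row_number(): 1, 2, 3, ... per partition
        for rank, (_, dest, num_fs) in enumerate(part, start=1):
            if rank <= n:
                out.append((orig, dest, num_fs, rank))
    return out


for row in top_n_routes(popular_routes, 2):
    print(row)
```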

The code

Here is all of the code in one file, aviation.py.

import trino
from pystarburst import Session
from pystarburst import functions as f
from pystarburst.functions import col, lag, round, row_number
from pystarburst.window import Window

db_parameters = {
    "host": "lXXXXXXXXXXXXr.trino.galaxy.starburst.io",
    "port": 443,
    "http_scheme": "https",
    "auth": trino.auth.BasicAuthentication("lXXXXXX/XXXXXXn", "<password>")
}
session = Session.builder.configs(db_parameters).create()


print("")
print("Q1 ---------------------------")
print("How many rows in the flight table?")

allFs = session.table("mycloud.aviation.raw_flight")
print(allFs.count())


print("")
print("Q2 ---------------------------")
print("What country are most of the airports")
print("  located in?")

# get the whole table, aggregate & sort
mostAs = session \
    .table("mycloud.aviation.raw_airport") \
	.group_by("country").count() \
	.sort("count", ascending=False)
mostAs.show(1)


print("")
print("Q3 ---------------------------")
print("What are the top 5 airline codes with ")
print("  the most number of flights?")

# get the whole table, aggregate & sort
mostFs = session \
    .table("mycloud.aviation.raw_flight") \
	.group_by("unique_carrier").count() \
	.rename("unique_carrier", "carr") \
	.sort("count", ascending=False)
mostFs.show(5)


print("")
print("Q4 ---------------------------")
print("Same question, but show the airline ") 
print("  carrier's name.")

# get all of the carriers
allCs = session.table("mycloud.aviation.raw_carrier")

# repurpose mostFs from above (or chain on it) 
#   to join the 2 DFs and sort the results that
#   have already been grouped
top5CarrNm = mostFs \
    .join(allCs, mostFs.carr == allCs.code) \
    .drop("code") \
	.sort("count", ascending=False)
top5CarrNm.show(5, 30)


print("")
print("Q5 ---------------------------")
print("What are the most common airplane models ") 
print("  for flights over 1500 miles?")

# trimFs are flights projected & filtered
trimFs = session.table("mycloud.aviation.raw_flight") \
	.rename("tail_number", "tNbr") \
	.select("tNbr", "distance") \
	.filter(col("distance") > 1500) 

# trimPs are planes table projected & filtered
trimPs = session.table("mycloud.aviation.raw_plane") \
	.select("tail_number", "model") \
	.filter("model is not null")

# join, group & sort
q5Answer = trimFs \
	.join(trimPs, trimFs.tNbr == trimPs.tail_number) \
	.drop("tail_number") \
	.group_by("model").count() \
	.sort("count", ascending=False)	
q5Answer.show()


print("")
print("Q6 ---------------------------")
print("What is the month over month percentage ")
print("  change of number of flights departing ")
print("  from each airport?")

# temp DF holds counts for each originating airport 
#   by month
aggFlights = session.table("mycloud.aviation.raw_flight") \
	.select("origination", "month") \
	.rename("origination", "orig") \
	.group_by("orig", "month").count() \
	.rename("count", "num_fs")

# define a window specification
w1 = Window.partition_by("orig").order_by("month")

# add col to grab the prior row's nbr flights
changeFlights = aggFlights \
	.withColumn("num_fs_b4", \
		lag("num_fs",1).over(w1))
	
# add col for the percentage change
q6Answer = changeFlights \
	.withColumn("perc_chg", \
		round((1.0 * (col("num_fs") - col("num_fs_b4")) / \
		      (1.0 * col("num_fs_b4"))), 1))
q6Answer.show()


print("")
print("Q7 ---------------------------")
print("Determine the top 3 routes departing from ")
print("  each airport. ")

# determine counts from orig>dest pairs
popularRoutes = session \
	.table("mycloud.aviation.raw_flight") \
	.rename("origination", "orig") \
	.rename("destination", "dest") \
	.group_by("orig", "dest").count() \
	.rename("count", "num_fs")

# define a window specification
w2 = Window.partition_by("orig") \
	.order_by(col("num_fs").desc())

# add col to put the curr row's ranking in
rankedRoutes = popularRoutes \
	.withColumn("rank", \
		row_number().over(w2))

# just show up to 3 for each orig airport
q7Answer = rankedRoutes \
	.filter(col("rank") <= 3) \
	.sort("orig", "rank")
q7Answer.show(17)

Published by lestermartin

Developer advocate, trainer, blogger, and data engineer focused on data lake & streaming frameworks including Trino, Hive, Spark, Flink, Kafka and NiFi.
