When loading data into the database, messages are sorted into SQL tables determined by the message type and month. The names of these tables follow the format below, where {YYYYMM} indicates the table's year and month in the format YYYYMM.
ais_{YYYYMM}_static # table with static AIS messages
ais_{YYYYMM}_dynamic # table with dynamic AIS messages
Some additional tables containing computed data may be created depending on the indexes used, for example, an aggregate of vessel static data by month, or a virtual table used as a covering index.
static_{YYYYMM}_aggregate # table of aggregated static vessel data
Additional tables are also included for storing data not directly derived from AIS message reports.
coarsetype_ref # a reference table that maps numeric ship type codes to their descriptions
hashmap # a table of file checksums recorded during decoding
For quick reference to data types and detailed explanations of these table entries, please see the Detailed Table Description.
In addition to querying the database using the DBQuery module, you can customize queries with your own SQL code.
Example of listing all the tables in your database:
import sqlite3
dbpath='YOUR_DATABASE.db' # Define the path to your database
# Connect to the database
connection = sqlite3.connect(dbpath)
# Create a cursor object
cursor = connection.cursor()
# Query to list all tables
query = "SELECT name FROM sqlite_master WHERE type='table';"
cursor.execute(query)
# Fetch the results
tables = cursor.fetchall()
# Print the names of the tables
print("Tables in the database:")
for table in tables:
    print(table[0])
# Close the connection
connection.close()
As messages are separated into tables by message type and month, queries spanning multiple message types or months should use UNIONs and JOINs to combine results as appropriate.
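For instance, here is a minimal sketch of combining two monthly dynamic tables with UNION ALL; it assumes tables for January and February 2018 exist in your database, so adjust the table names to match your data:
import sqlite3

connection = sqlite3.connect('YOUR_DATABASE.db')
cursor = connection.cursor()
# Combine position reports from two consecutive months into one result set
query = """
SELECT mmsi, time, longitude, latitude FROM ais_201801_dynamic
UNION ALL
SELECT mmsi, time, longitude, latitude FROM ais_201802_dynamic;
"""
cursor.execute(query)
rows = cursor.fetchall()
print(f"Total rows across both months: {len(rows)}")
connection.close()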
Example of querying tables with `JOIN`:
import sqlite3
# Connect to the database
connection = sqlite3.connect('YOUR_DATABASE.db')
# Create a cursor object
cursor = connection.cursor()
# Define the year and month of the tables to query (placeholder; '201801' is just an example value)
YYYYMM = '201801'
# Define the JOIN SQL query
query = f"""
SELECT
d.mmsi,
d.time,
d.longitude,
d.latitude,
d.sog,
d.cog,
s.vessel_name,
s.ship_type
FROM ais_{YYYYMM}_dynamic d
LEFT JOIN ais_{YYYYMM}_static s ON d.mmsi = s.mmsi
WHERE d.time BETWEEN 1707033659 AND 1708176856 -- Filter by time range
AND d.longitude BETWEEN -68 AND -56 -- Filter by geographical area
AND d.latitude BETWEEN 45 AND 51.5;
"""
# Execute the query
cursor.execute(query)
# Fetch the results
results = cursor.fetchall()
# Print the results
for row in results:
print(row)
# Close the connection
connection.close()
More information about SQL queries can be found in online tutorials.
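Ship-type codes in the static tables can be resolved to human-readable descriptions by joining against coarsetype_ref. The coarsetype_ref column names used below are assumptions for illustration; verify them first, for example with PRAGMA table_info(coarsetype_ref):
import sqlite3

connection = sqlite3.connect('YOUR_DATABASE.db')
cursor = connection.cursor()
# NOTE: coarse_type and coarse_type_txt are assumed column names; check your schema before running
query = """
SELECT s.mmsi, s.vessel_name, c.coarse_type_txt
FROM ais_201801_static s
LEFT JOIN coarsetype_ref c ON s.ship_type = c.coarse_type
LIMIT 10;
"""
cursor.execute(query)
for row in cursor.fetchall():
    print(row)
connection.close()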
The R* tree virtual tables should be queried for AIS position reports instead of the default tables. Query performance can be significantly improved using the R* tree index when restricting output to a narrow range of MMSIs, timestamps, longitudes, and latitudes. However, querying a wide range will not yield much benefit. If custom indexes are required for specific manual queries, these should be defined on message tables 1_2_3, 5, 18, and 24 directly rather than on the virtual tables.
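As a rough illustration only: the sketch below restricts the query to a narrow time and coordinate window. The virtual-table name and its paired min/max column names are assumptions and vary between AISdb versions, so list your tables (as shown earlier) and inspect their schema with .schema before adapting it:
import sqlite3

connection = sqlite3.connect('YOUR_DATABASE.db')
cursor = connection.cursor()
# NOTE: "rtree_201801_dynamic" and its column names are hypothetical placeholders
query = """
SELECT mmsi0, t0, x0, y0
FROM rtree_201801_dynamic
WHERE t0 >= 1514764800 AND t1 <= 1514851200  -- narrow time window (epoch seconds)
  AND x0 >= -64.0 AND x1 <= -63.0            -- narrow longitude range
  AND y0 >= 44.0  AND y1 <= 45.0;            -- narrow latitude range
"""
cursor.execute(query)
print(cursor.fetchall())
connection.close()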
Timestamps are stored as epoch seconds in the database. To facilitate querying the database manually, use the dt_2_epoch() function to convert datetime values to epoch time, and the epoch_2_dt() function to convert epoch values back to datetimes. Here is how you can use dt_2_epoch() with the example above:
import sqlite3
from datetime import datetime
from aisdb.gis import dt_2_epoch
# Define the datetime range
start_datetime = datetime(2018, 1, 1, 0, 0, 0)
end_datetime = datetime(2018, 1, 1, 1, 59, 59)
# Convert datetime to epoch time
start_epoch = dt_2_epoch(start_datetime)
end_epoch = dt_2_epoch(end_datetime)
# Connect to the database
connection = sqlite3.connect('YOUR_DATABASE.db')
# Create a cursor object
cursor = connection.cursor()
# Define the JOIN SQL query using an epoch time range
query = f"""
SELECT
d.mmsi,
d.time,
d.longitude,
d.latitude,
d.sog,
d.cog,
s.vessel_name,
s.ship_type
FROM ais_201801_dynamic d
LEFT JOIN ais_201801_static s ON d.mmsi = s.mmsi
WHERE d.time BETWEEN {start_epoch} AND {end_epoch} -- Filter by time range
AND d.longitude BETWEEN -68 AND -56 -- Filter by geographical area
AND d.latitude BETWEEN 45 AND 51.5;
"""
# Execute the query
cursor.execute(query)
# Fetch the results
results = cursor.fetchall()
# Print the results
for row in results:
print(row)
# Close the connection
connection.close()
For more examples, please see the SQL code in aisdb_sql/ that is used to create database tables and associated queries.
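To convert the epoch values returned by these queries back into datetimes, the epoch_2_dt() function can be used. A minimal sketch, assuming the results list from the example above is still available and that epoch_2_dt() accepts a single epoch value:
from aisdb.gis import epoch_2_dt

# row[1] is d.time in the SELECT above; row[0] is d.mmsi
for row in results:
    print(epoch_2_dt(row[1]), row[0])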
ais_{YYYYMM}_dynamic tables

mmsi (INTEGER): Maritime Mobile Service Identity, a unique identifier for vessels.
time (INTEGER): Timestamp of the AIS message, in epoch seconds.
longitude (REAL): Longitude of the vessel in decimal degrees.
latitude (REAL): Latitude of the vessel in decimal degrees.
rot (REAL): Rate of turn, indicating how fast the vessel is turning.
sog (REAL): Speed over ground, in knots.
cog (REAL): Course over ground, in degrees.
heading (REAL): Heading of the vessel, in degrees.
maneuver (BOOLEAN): Indicator for whether the vessel is performing a special maneuver.
utc_second (INTEGER): Second of the UTC timestamp when the message was generated.
source (TEXT): Source of the AIS data.
ais_{YYYYMM}_static tables

mmsi (INTEGER): Maritime Mobile Service Identity, a unique identifier for vessels.
time (INTEGER): Timestamp of the AIS message, in epoch seconds.
vessel_name (TEXT): Name of the vessel.
ship_type (INTEGER): Numeric code representing the type of ship.
call_sign (TEXT): International radio call sign of the vessel.
imo (INTEGER): International Maritime Organization number, another unique vessel identifier.
dim_bow (INTEGER): Distance from the AIS transmitter to the bow (front) of the vessel.
dim_stern (INTEGER): Distance from the AIS transmitter to the stern (back) of the vessel.
dim_port (INTEGER): Distance from the AIS transmitter to the port (left) side of the vessel.
dim_star (INTEGER): Distance from the AIS transmitter to the starboard (right) side of the vessel.
draught (REAL): Maximum depth of the vessel's hull below the waterline, in meters.
destination (TEXT): Destination port or location where the vessel is heading.
ais_version (INTEGER): AIS protocol version used by the vessel.
fixing_device (TEXT): Type of device used for fixing the vessel's position (e.g., GPS).
eta_month (INTEGER): Estimated time of arrival month.
eta_day (INTEGER): Estimated time of arrival day.
eta_hour (INTEGER): Estimated time of arrival hour.
eta_minute (INTEGER): Estimated time of arrival minute.
source (TEXT): Source of the AIS data (e.g., specific AIS receiver or data provider).
static_{YYYYMM}_aggregate tables

mmsi (INTEGER): Maritime Mobile Service Identity, a unique identifier for vessels.
imo (INTEGER): International Maritime Organization number, another unique vessel identifier.
vessel_name (TEXT): Name of the vessel.
ship_type (INTEGER): Numeric code representing the type of ship.
call_sign (TEXT): International radio call sign of the vessel.
dim_bow (INTEGER): Distance from the AIS transmitter to the bow (front) of the vessel.
dim_stern (INTEGER): Distance from the AIS transmitter to the stern (back) of the vessel.
dim_port (INTEGER): Distance from the AIS transmitter to the port (left) side of the vessel.
dim_star (INTEGER): Distance from the AIS transmitter to the starboard (right) side of the vessel.
draught (REAL): Maximum depth of the vessel's hull below the waterline, in meters.
destination (TEXT): Destination port or location where the vessel is heading.
eta_month (INTEGER): Estimated time of arrival month.
eta_day (INTEGER): Estimated time of arrival day.
eta_hour (INTEGER): Estimated time of arrival hour.
eta_minute (INTEGER): Estimated time of arrival minute.
In addition to accessing data stored on the AISdb server, you can download open-source AIS data or import your datasets for processing and analysis using AISdb. This tutorial guides you through downloading AIS data from popular websites, creating SQLite and PostgreSQL databases compatible with AISdb, and establishing database connections. We provide two examples: Downloading and Processing Individual Files, which demonstrates working with small data samples and creating an SQLite database, and Pipeline for Bulk File Downloads and Database Integration, which outlines our approach to handling multiple data file downloads and creating a PostgreSQL database.
U.S. vessel traffic data across user-defined geographies and periods is available from MarineCadastre. This resource offers comprehensive AIS data that can be accessed for various maritime analysis purposes. We can tailor the dataset based on research needs by selecting specific regions and timeframes.
In the following example, we will show how to download and process a single data file and import the data to a newly created SQLite database.
First, download the AIS data of the day using the curl command:
curl -o ./data/AIS_2020_01_01.zip https://coast.noaa.gov/htdata/CMSP/AISDataHandler/2020/AIS_2020_01_01.zip
Then, extract the downloaded ZIP file to a specific path:
unzip ./data/AIS_2020_01_01.zip -d ./data/
Next, let's inspect the columns in the downloaded CSV file.
import pandas as pd
# Read CSV file in pandas dataframe
df_ = pd.read_csv("./data/AIS_2020_01_01.csv", parse_dates=["BaseDateTime"])
print(df_.columns)
Index(['MMSI', 'BaseDateTime', 'LAT', 'LON', 'SOG', 'COG', 'Heading',
'VesselName', 'IMO', 'CallSign', 'VesselType', 'Status',
'Length', 'Width', 'Draft', 'Cargo', 'TransceiverClass'],
dtype='object')
The columns required by AISdb have specific names that may differ from those in the imported dataset, so let's define the exact list of columns needed.
list_of_headers_ = ["MMSI","Message_ID","Repeat_indicator","Time","Millisecond","Region","Country","Base_station","Online_data","Group_code","Sequence_ID","Channel","Data_length","Vessel_Name","Call_sign","IMO","Ship_Type","Dimension_to_Bow","Dimension_to_stern","Dimension_to_port","Dimension_to_starboard","Draught","Destination","AIS_version","Navigational_status","ROT","SOG","Accuracy","Longitude","Latitude","COG","Heading","Regional","Maneuver","RAIM_flag","Communication_flag","Communication_state","UTC_year","UTC_month","UTC_day","UTC_hour","UTC_minute","UTC_second","Fixing_device","Transmission_control","ETA_month","ETA_day","ETA_hour","ETA_minute","Sequence","Destination_ID","Retransmit_flag","Country_code","Functional_ID","Data","Destination_ID_1","Sequence_1","Destination_ID_2","Sequence_2","Destination_ID_3","Sequence_3","Destination_ID_4","Sequence_4","Altitude","Altitude_sensor","Data_terminal","Mode","Safety_text","Non-standard_bits","Name_extension","Name_extension_padding","Message_ID_1_1","Offset_1_1","Message_ID_1_2","Offset_1_2","Message_ID_2_1","Offset_2_1","Destination_ID_A","Offset_A","Increment_A","Destination_ID_B","offsetB","incrementB","data_msg_type","station_ID","Z_count","num_data_words","health","unit_flag","display","DSC","band","msg22","offset1","num_slots1","timeout1","Increment_1","Offset_2","Number_slots_2","Timeout_2","Increment_2","Offset_3","Number_slots_3","Timeout_3","Increment_3","Offset_4","Number_slots_4","Timeout_4","Increment_4","ATON_type","ATON_name","off_position","ATON_status","Virtual_ATON","Channel_A","Channel_B","Tx_Rx_mode","Power","Message_indicator","Channel_A_bandwidth","Channel_B_bandwidth","Transzone_size","Longitude_1","Latitude_1","Longitude_2","Latitude_2","Station_Type","Report_Interval","Quiet_Time","Part_Number","Vendor_ID","Mother_ship_MMSI","Destination_indicator","Binary_flag","GNSS_status","spare","spare2","spare3","spare4"]Next, we update the name of columns in the existing dataframe df_ and change the time format as required. The timestamp of an AIS message is represented by BaseDateTime in the default format YYYY-MM-DDTHH:MM:SS. For AISdb, however, the time is represented in UNIX format. We now read the CSV and apply the necessary changes to the date format:
# Take the first 40,000 records from the original dataframe
df = df_.iloc[0:40000]
# Create a new dataframe with the specified headers
df_new = pd.DataFrame(columns=list_of_headers_)
# Populate the new dataframe with formatted data from the original dataframe
df_new['Time'] = pd.to_datetime(df['BaseDateTime']).dt.strftime('%Y%m%d_%H%M%S')
df_new['Latitude'] = df['LAT']
df_new['Longitude'] = df['LON']
df_new['Vessel_Name'] = df['VesselName']
df_new['Call_sign'] = df['CallSign']
df_new['Ship_Type'] = df['VesselType'].fillna(0).astype(int)
df_new['Navigational_status'] = df['Status']
df_new['Draught'] = df['Draft']
df_new['Message_ID'] = 1 # Mark all messages as dynamic by default
df_new['Millisecond'] = 0
# Transfer additional columns from the original dataframe, if they exist
for col_n in df_new:
    if col_n in df.columns:
        df_new[col_n] = df[col_n]
# Extract static messages for each unique vessel
filtered_df = df_new[df_new['Ship_Type'].notnull() & (df_new['Ship_Type'] != 0)]
filtered_df = filtered_df.drop_duplicates(subset='MMSI', keep='first')
filtered_df = filtered_df.reset_index(drop=True)
filtered_df['Message_ID'] = 5 # Mark these as static messages
# Merge dynamic and static messages into a single dataframe
df_new = pd.concat([filtered_df, df_new])
# Save the final dataframe to a CSV file
# The quoting parameter is necessary because the csvreader reads each column value as a string by default
df_new.to_csv("./data/AIS_2020_01_01_aisdb.csv", index=False, quoting=1)In the code, we can see that we have mapped the column named accordingly. Additionally, the data type of some columns has also been changed. Additionally, the nm4 file usually contains raw messages, separating static messages from dynamic ones. However, the MarineCadastre Data does not have such a Message_ID to indicate the type. Thus, adding static messages is necessary for database creation so that a table related to metadata is created.
Let's process the CSV to create an SQLite database using the aisdb package.
import aisdb
# Establish a connection to the SQLite database and decode messages from the CSV file
with aisdb.SQLiteDBConn('./data/test_decode_msgs.db') as dbconn:
    aisdb.decode_msgs(filepaths=["./data/AIS_2020_01_01_aisdb.csv"],
                      dbconn=dbconn, source='Testing', verbose=True)
generating file checksums...
checking file dates...
creating tables and dropping table indexes...
Memory: 20.65GB remaining. CPUs: 12. Average file size: 49.12MB Spawning 4 workers
saving checksums...
processing ./data/AIS_2020_01_01_aisdb.csv
AIS_2020_01_01_aisdb.csv count: 49323 elapsed: 0.27s rate: 183129 msgs/s
cleaning temporary data...
aggregating static reports into static_202001_aggregate...
An SQLite database has now been created.
sqlite3 ./data/test_decode_msgs.db
sqlite> .tables
ais_202001_dynamic coarsetype_ref static_202001_aggregate
ais_202001_static        hashmap
If you prefer to use a PostgreSQL database instead, define a PostgreSQL connection string and proceed with the database connection:
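A minimal sketch, assuming a running PostgreSQL/TimescaleDB server and that decode_msgs() accepts a PostgresDBConn in the same way as an SQLiteDBConn (replace the placeholder credentials with your own):
import aisdb
from aisdb.database.dbconn import PostgresDBConn

# Hypothetical connection string: postgresql://USERNAME:PASSWORD@HOST:PORT/DATABASE
dbconn = PostgresDBConn('postgresql://USERNAME:PASSWORD@localhost:5432/DATABASE')
# Assumption: decode_msgs() works with PostgresDBConn just as it does with SQLiteDBConn
aisdb.decode_msgs(filepaths=["./data/AIS_2020_01_01_aisdb.csv"],
                  dbconn=dbconn, source='Testing', verbose=True)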
This section provides an example of downloading and processing multiple files, creating a PostgreSQL database, and loading data into tables. The steps are outlined in a series of pipeline scripts available in this GitHub repository, which should be executed in the order indicated by their numbers.
The first script, 0-download-ais.py, allows you to download AIS data from MarineCadastre by specifying your needed years. If no years are specified, the script will default to downloading data for 2023. The downloaded ZIP files will be stored in a /data folder created in your current working directory. The second script, 1-zip2csv.py, extracts the CSV files from the downloaded ZIP files in /data and saves them in a new directory named /zip.
To download and extract the data, simply run the two scripts in sequence:
python 0-download-ais.py
python 1-zip2csv.py
After downloading and extracting the AIS data, the 2-merge.py script consolidates the daily CSV files into monthly files, while the 3-deduplicate.py script removes duplicate rows, retaining only unique AIS messages. To run them, execute:
python 2-merge.py
python 3-deduplicate.py
The output of these two scripts is a set of cleaned CSV files stored in a new folder named /merged in your working directory.
The final script, 4-postgresql-database.py, creates a PostgreSQL database with a specified name. To do this, the script connects to a PostgreSQL server, requiring you to provide your username and password to establish the connection. After creating the database, the script verifies that the number of columns in the CSV files matches the headers. It then creates a corresponding table in the database for each CSV file and loads the data into it. To run this script, you need to provide three command-line arguments: -dbname for the new database name, -user for your PostgreSQL username, and -password for your PostgreSQL password. There are also two optional arguments, -host (default localhost) and -port (default 5432), which you can adjust if your PostgreSQL server is running on a different host or port.
python 4-postgresql-database.py -dbname DBNAME -user USERNAME -password PASSWORD [-host HOST] [-port PORT]
When the program reports that the task is finished, you can check the created database and loaded tables by connecting to the PostgreSQL server with the psql command-line interface:
psql -U USERNAME -d DBNAME -h localhost -p 5432
Once connected, you can list all tables in the database by running the \dt command. In our example using 2023 AIS data (default download), the tables will appear as follows:
ais_pgdb=# \dt
List of relations
Schema | Name | Type | Owner
--------+-------------+-------+----------
public | ais_2023_01 | table | postgres
public | ais_2023_02 | table | postgres
public | ais_2023_03 | table | postgres
public | ais_2023_04 | table | postgres
public | ais_2023_05 | table | postgres
public | ais_2023_06 | table | postgres
public | ais_2023_07 | table | postgres
public | ais_2023_08 | table | postgres
public | ais_2023_09 | table | postgres
public | ais_2023_10 | table | postgres
public | ais_2023_11 | table | postgres
public | ais_2023_12 | table | postgres
(12 rows)
This tutorial demonstrates how to access vessel metadata using MMSI and SQLite databases. In many cases, AIS messages do not contain metadata. Therefore, this tutorial introduces the built-in functions in AISdb and external APIs to extract detailed vessel information associated with a specific MMSI from web sources.
We introduce two methods implemented in AISdb for scraping metadata: using session requests for direct access, and employing web drivers with browsers to handle modern websites with dynamic content. Additionally, we provide an example of using a third-party API to access vessel information.
The session request method in Python is a straightforward and efficient approach for retrieving metadata from websites. In AISdb, the aisdb.webdata._scraper.search_metadata_vesselfinder function leverages this method to scrape detailed information about vessels based on their MMSI numbers. This function efficiently gathers a range of data, including vessel name, type, flag, tonnage, and navigation status.
This is an example of how to use the search_metadata_vesselfinder function in AISdb to scrape data from the VesselFinder website:
from aisdb.webdata._scraper import search_metadata_vesselfinder
MMSI = 228386800
dict_ = search_metadata_vesselfinder(MMSI)
print(dict_)
{'IMO number': '9839131',
'Vessel Name': 'CMA CGM CHAMPS ELYSEES',
'Ship type': 'Container Ship',
'Flag': 'France',
'Homeport': '-',
'Gross Tonnage': '236583',
'Summer Deadweight (t)': '220766',
'Length Overall (m)': '400',
'Beam (m)': '61',
'Draught (m)': '',
'Year of Build': '2020',
'Builder': '',
'Place of Build': '',
'Yard': '',
'TEU': '',
'Crude Oil (bbl)': '-',
'Gas (m3)': '-',
'Grain': '-',
'Bale': '-',
'Classification Society': '',
'Registered Owner': '',
'Owner Address': '',
'Owner Website': '-',
'Owner Email': '-',
'Manager': '',
'Manager Address': '',
'Manager Website': '',
'Manager Email': '',
'Predicted ETA': '',
'Distance / Time': '',
'Course / Speed': '\xa0',
'Current draught': '16.0 m',
'Navigation Status': '\nUnder way\n',
'Position received': '\n22 mins ago \n\n\n',
'IMO / MMSI': '9839131 / 228386800',
'Callsign': 'FLZF',
'Length / Beam': '399 / 61 m'}
In addition to metadata scraping, we can also use APIs offered by data providers. MarineTraffic offers a subscription API for accessing vessel data, voyage forecasts, vessel positions, and more. Here is an example of retrieving vessel data with it:
import requests
# Your MarineTraffic API key
api_key = 'your_marine_traffic_api_key'
# List of MMSI numbers you want to query
mmsi_list = [228386800,
372351000,
373416000,
477003800,
477282400
]
# Base URL for the MarineTraffic API endpoint
url = f'https://services.marinetraffic.com/api/exportvessels/{api_key}'
# Prepare the API request
params = {
'shipid': ','.join(map(str, mmsi_list)), # Join MMSI list with commas (as strings)
'protocol': 'jsono', # Specify the response format
'msgtype': 'extended' # Specify the level of details
}
# Make the API request
response = requests.get(url, params=params)
# Check if the request was successful
if response.status_code == 200:
    vessel_data = response.json()
    for vessel in vessel_data:
        print(f"Vessel Name: {vessel.get('NAME')}")
        print(f"MMSI: {vessel.get('MMSI')}")
        print(f"IMO: {vessel.get('IMO')}")
        print(f"Call Sign: {vessel.get('CALLSIGN')}")
        print(f"Type: {vessel.get('TYPE_NAME')}")
        print(f"Flag: {vessel.get('COUNTRY')}")
        print(f"Length: {vessel.get('LENGTH')}")
        print(f"Breadth: {vessel.get('BREADTH')}")
        print(f"Year Built: {vessel.get('YEAR_BUILT')}")
        print(f"Status: {vessel.get('STATUS_NAME')}")
        print('-' * 40)
else:
print(f"Failed to retrieve data: {response.status_code}")If you already have a database containing AIS track data, then vessel metadata can be downloaded and stored in a separate database.
from aisdb import track_gen, DBConn, DBQuery, sqlfcn_callbacks, Domain
from datetime import datetime
dbpath = "/home/database.db"
# A new database will be created if it does not exist to save the info downloaded from MarineTraffic
traffic_database_path = "/home/traffic_info.db"
with DBConn(dbpath=dbpath) as dbconn:
    qry = DBQuery(
        dbconn=dbconn, callback=sqlfcn_callbacks.in_timerange,
        start=datetime(2021, 11, 1),
        end=datetime(2021, 11, 2)
    )
    # A custom boundary can be selected for the query using aisdb.Domain
    qry.check_marinetraffic(trafficDBpath=traffic_database_path,
                            boundary={"xmin": -180, "xmax": 180, "ymin": -90, "ymax": 90})
    rowgen = qry.gen_qry(verbose=True)
    trackgen = track_gen.TrackGen(rowgen, decimate=True)
In AISdb, the speed of a vessel is calculated using the aisdb.gis.delta_knots function, which computes the speed over ground (SOG) in knots between consecutive positions within a given track. This calculation is important for the denoising encoder, as it compares the vessel's speed against a set threshold to aid in the data cleaning process.
Vessel speed calculation requires the distance the vessel has traveled between two consecutive positions and the time interval. This distance is computed using the haversine distance function, and the time interval is simply the difference in timestamps between the two consecutive AIS position reports. The speed is then computed using the formula:
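Reconstructed from the description above, with the distance in meters and the time difference in seconds:
speed (knots) = haversine_distance(p1, p2) / (t2 - t1) × 1.9438445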
The factor 1.9438445 converts the speed from meters per second to knots (1 knot = 1852 m per hour, so 1 m/s = 3600/1852 ≈ 1.9438 knots), the standard speed unit used in maritime contexts.
With the example track we created in Haversine Distance, we can calculate the vessel speed between each two consecutive positions:
import aisdb
import numpy as np
from datetime import datetime
from aisdb.gis import dt_2_epoch
# Generate example track
y1, x1 = 44.57039426840729, -63.52931373766157
y2, x2 = 44.51304767533133, -63.494075674952555
y3, x3 = 44.458038982492134, -63.535634138077945
y4, x4 = 44.393941339104074, -63.53826396955358
y5, x5 = 44.14245580737021, -64.16608964280064
t1 = dt_2_epoch( datetime(2021, 1, 1, 1) )
t2 = dt_2_epoch( datetime(2021, 1, 1, 2) )
t3 = dt_2_epoch( datetime(2021, 1, 1, 3) )
t4 = dt_2_epoch( datetime(2021, 1, 1, 4) )
t5 = dt_2_epoch( datetime(2021, 1, 1, 7) )
# Create a sample track
tracks_short = [
dict(
mmsi=123456789,
lon=np.array([x1, x2, x3, x4, x5]),
lat=np.array([y1, y2, y3, y4, y5]),
time=np.array([t1, t2, t3, t4, t5]),
dynamic=set(['lon', 'lat', 'time']),
static=set(['mmsi'])
)
]
# Calculate the vessel speed in knots
for track in tracks_short:
    print(aisdb.gis.delta_knots(track))
[3.7588560005768947 3.7519408684140214 3.8501088005116215 10.309565520121597]
Building on the previous section, where we used AIS data to create AISdb databases, users can export AIS data from these databases into CSV format. In this section, we provide examples of exporting data from SQLite or PostgreSQL databases into CSV files. While we demonstrate these operations using internal data, you can apply the same techniques to your databases.
In the first example, we connected to a SQLite database, queried data in a specific time range and area of interest, and then exported the queried data to a CSV file:
import csv
import aisdb
import nest_asyncio
from aisdb import DBConn, DBQuery, DomainFromPoints
from aisdb.database.dbconn import SQLiteDBConn
from datetime import datetime
nest_asyncio.apply()
dbpath = 'YOUR_DATABASE.db' # Path to your database
end_time = datetime.strptime("2018-01-02 00:00:00", '%Y-%m-%d %H:%M:%S')
start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000])
# Connect to SQLite database
with SQLiteDBConn(dbpath=dbpath) as dbconn:
    qry = DBQuery(
        dbconn=dbconn, start=start_time, end=end_time,
        xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
        ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
        callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
    )
    tracks = aisdb.track_gen.TrackGen(qry.gen_qry(), decimate=False)
    # Define the headers for the CSV file
    headers = ['mmsi', 'time', 'lon', 'lat', 'cog', 'sog',
               'utc_second', 'heading', 'rot', 'maneuver']
    # Open the CSV file for writing
    csv_filename = 'output_sqlite.csv'
    with open(csv_filename, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=headers)
        writer.writeheader()  # Write the header once
        for track in tracks:
            for i in range(len(track['time'])):
                row = {
                    'rot': track['rot'],
                    'mmsi': track['mmsi'],
                    'lon': track['lon'][i],
                    'lat': track['lat'][i],
                    'cog': track['cog'][i],
                    'sog': track['sog'][i],
                    'time': track['time'][i],
                    'heading': track['heading'],
                    'maneuver': track['maneuver'],
                    'utc_second': track['utc_second'][i],
                }
                writer.writerow(row)  # Write the row to the CSV file
print(f"All tracks have been combined and written to {csv_filename}")
Now we can check the data in the exported CSV file:
mmsi time lon lat cog sog utc_second heading rot maneuver
0 219014000 1514767484 -63.537167 44.635834 322 0.0 44 295.0 0.0 0
1 219014000 1514814284 -63.537167 44.635834 119 0.0 45 295.0 0.0 0
2 219014000 1514829783 -63.537167 44.635834 143 0.0 15 295.0 0.0 0
3 219014000 1514829843 -63.537167 44.635834 171 0.0 15 295.0 0.0 0
4 219014000 1514830042 -63.537167 44.635834 3 0.0 35 295.0 0.0 0
Exporting from a PostgreSQL database works the same way; the only difference is that you connect to your PostgreSQL database before querying the data you want to export. A full example follows:
import csv
import aisdb
import nest_asyncio
from datetime import datetime
from aisdb.database.dbconn import PostgresDBConn
from aisdb import DBConn, DBQuery, DomainFromPoints
nest_asyncio.apply()
dbconn = PostgresDBConn(
host='localhost', # PostgreSQL address
port=5432, # PostgreSQL port
user='your_username', # PostgreSQL username
password='your_password', # PostgreSQL password
dbname='database_name' # Database name
)
# Define the area of interest; here a 50 km radius around a point, as in the SQLite example above
domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000])
qry = DBQuery(
    dbconn=dbconn,
    start=datetime(2023, 1, 1), end=datetime(2023, 1, 3),
    xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
    ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
    callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi
)
tracks = aisdb.track_gen.TrackGen(qry.gen_qry(), decimate=False)
# Define the headers for the CSV file
headers = ['mmsi', 'time', 'lon', 'lat', 'cog', 'sog',
'utc_second', 'heading', 'rot', 'maneuver']
# Open the CSV file for writing
csv_filename = 'output_postgresql.csv'
with open(csv_filename, mode='w', newline='') as file:
    writer = csv.DictWriter(file, fieldnames=headers)
    writer.writeheader()  # Write the header once
    for track in tracks:
        for i in range(len(track['time'])):
            row = {
                'rot': track['rot'],
                'mmsi': track['mmsi'],
                'lon': track['lon'][i],
                'lat': track['lat'][i],
                'cog': track['cog'][i],
                'sog': track['sog'][i],
                'time': track['time'][i],
                'heading': track['heading'],
                'maneuver': track['maneuver'],
                'utc_second': track['utc_second'][i],
            }
            writer.writerow(row)  # Write the row to the CSV file
print(f"All tracks have been combined and written to {csv_filename}")We can check the output CSV file now:
mmsi time lon lat cog sog utc_second heading rot maneuver
0 210108000 1672545711 -63.645 44.68833 173 0.0 0 0.0 0.0 False
1 210108000 1672545892 -63.645 44.68833 208 0.0 0 0.0 0.0 False
2 210108000 1672546071 -63.645 44.68833 176 0.0 0 0.0 0.0 False
3 210108000 1672546250 -63.645 44.68833 50 0.0 0 0.0 0.0 False
4 210108000 1672546251 -63.645 44.68833 50 0.0 0 0.0 0.0 False
A hands-on quick start guide for using AISdb.
Note: If you are starting from scratch, download the data ".db" file in our AISdb Tutorial GitHub repository so that you can follow this guide properly.
To work with the AISdb Python package, please ensure you have Python version 3.8 or higher. If you plan to use SQLite, no additional installation is required, as it is included with Python by default. However, those who prefer using a PostgreSQL server must install it separately and enable the TimescaleDB extension to function correctly.
The AISdb Python package can be conveniently installed using pip. It's highly recommended that a virtual Python environment be created and the package installed within it.
python -m venv AISdb # create a python virtual environment
source ./AISdb/bin/activate # activate the virtual environment
pip install aisdb # from https://pypi.org/project/aisdb/
On Windows, use:
python -m venv AISdb
./AISdb/Scripts/activate
pip install aisdb
You can test your installation by running the following commands:
python
>>> import aisdb
>>> aisdb.__version__ # should return '1.7.3' or newer
Note that if you are running Jupyter, ensure it is installed in the same environment as AISdb:
source ./AISdb/bin/activate
pip install jupyter
jupyter notebook
The Python code in the rest of this document can be run in the Python environment you created.
To use nightly builds (optional), you can install AISdb from source:
source AISdb/bin/activate # On Windows use `AISdb\Scripts\activate`
# Cloning the Repository and installing the package
git clone https://github.com/AISViz/AISdb.git && cd AISdb
# Windows users can instead download the installer:
# - https://forge.rust-lang.org/infra/other-installation-methods.html#rustup
# - https://static.rust-lang.org/rustup/dist/i686-pc-windows-gnu/rustup-init.exe
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs > install-rust.sh
# Installing Rust and Maturin
/bin/bash install-rust.sh -q -y
pip install --upgrade maturin[patchelf]
# Building AISdb package with Maturin
maturin develop --release --extras=test,docs
Alternatively, you can use nightly builds (not mandatory) on Google Colab as follows:
import os
# Clone the AISdb repository from GitHub
!git clone https://github.com/AISViz/AISdb.git
# Install Rust using the official Rustup script
!curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
# Install Maturin to build the packages
!pip install --upgrade maturin[patchelf]
# Set up environment variables
os.environ["PATH"] += os.pathsep + "/root/.cargo/bin"
# Install wasm-pack for building WebAssembly packages
!curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
# Install wasm-pack as a Cargo dependency
!cargo install wasm-pack
# Setting environment variable for the virtual environment
os.environ["VIRTUAL_ENV"] = "/usr/local"
# Change directory to AISdb for building the package
%cd AISdb
# Build and install the AISdb package using Maturin
!maturin develop --release --extras=test,docs
AISdb supports SQLite and PostgreSQL databases. Since version 1.7.3, AISdb requires TimescaleDB over PostgreSQL to function properly. To install TimescaleDB, follow these steps:
Install TimescaleDB (PostgreSQL Extension)
$ sudo apt install -y timescaledb-postgresql-XX # XX is the PostgreSQL version
Enable the Extension in PostgreSQL
> CREATE EXTENSION IF NOT EXISTS timescaledb;
Verify the Installation
> SELECT * FROM timescaledb_information.version;
Restart PostgreSQL
$ sudo systemctl restart postgresql
Connecting to PostgreSQL requires the optional dependency psycopg for interfacing with Postgres databases. The PostgresDBConn class accepts the keyword arguments shown below; alternatively, a connection string may be used. Information on connection strings and the Postgres URI format can be found here.
from aisdb.database.dbconn import PostgresDBConn
# [OPTION 1]
dbconn = PostgresDBConn(
hostaddr='127.0.0.1', # Replace this with the Postgres address (supports IPv6)
port=5432, # Replace this with the Postgres running port (if not the default)
user='USERNAME', # Replace this with the Postgres username
password='PASSWORD', # Replace this with your password
dbname='DATABASE', # Replace this with your database name
)
# [OPTION 2]
dbconn = PostgresDBConn('postgresql://USERNAME:PASSWORD@HOST:PORT/DATABASE')
Querying SQLite is as easy as providing the path to a ".db" file with the same entity-relationship model as the databases supported by AISdb, which is detailed in the SQL Database section. We prepared an example SQLite database, example_data.db, based on AIS data from a small region near Maine, United States, in January 2022 from Marine Cadastre; it is available in the AISdb Tutorial GitHub repository.
from aisdb.database.dbconn import SQLiteDBConn
dbpath='example_data.db'
dbconn = SQLiteDBConn(dbpath=dbpath)
If you want to create your own database from your own data, we have a tutorial with examples showing how to create an SQLite database from open-source data.
Parameters for the database query can be defined using aisdb.database.dbqry.DBQuery. Iterate over rows returned from the database for each vessel with aisdb.database.dbqry.DBQuery.gen_qry(). Convert the results into a generator yielding dictionaries with NumPy arrays describing position vectors, e.g., lon, lat, and time, using aisdb.track_gen.TrackGen().
The following query will return vessel trajectories from a given 1-hour time window:
import aisdb
import pandas as pd
from datetime import datetime
from collections import defaultdict
dbpath = 'example_data.db'
start_time = datetime.strptime("2022-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
end_time = datetime.strptime("2022-01-01 0:59:59", '%Y-%m-%d %H:%M:%S')
def data2frame(tracks):
    # Dictionary mapping each MMSI (key) to a DataFrame of derived features (value)
    ais_data = defaultdict(lambda: pd.DataFrame(
        columns=['time', 'lat', 'lon', 'cog', 'rocog', 'sog', 'delta_sog']))
    for track in tracks:
        mmsi = track['mmsi']
        df = pd.DataFrame({
            'time': pd.to_datetime(track['time'], unit='s'),
            'lat': track['lat'], 'lon': track['lon'],
            'cog': track['cog'], 'sog': track['sog']
        })
        # Sort by time in descending order
        df = df.sort_values(by='time', ascending=False).reset_index(drop=True)
        # Compute the time difference in seconds
        df['time_diff'] = df['time'].diff().dt.total_seconds()
        # Compute RoCOG (Rate of Change of Course Over Ground)
        delta_cog = (df['cog'].diff() + 180) % 360 - 180
        df['rocog'] = delta_cog / df['time_diff']
        # Compute Delta SOG (Rate of Change of Speed Over Ground)
        df['delta_sog'] = df['sog'].diff() / df['time_diff']
        # Fill NaN values (first row) and infinite values (division by zero cases)
        df[['rocog', 'delta_sog']] = df[['rocog', 'delta_sog']].replace([float('inf'), float('-inf')], 0).fillna(0)
        # Drop unnecessary column
        df.drop(columns=['time_diff'], inplace=True)
        # Store in the dictionary
        ais_data[mmsi] = df
    return ais_data
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
    qry = aisdb.DBQuery(
        dbconn=dbconn, start=start_time, end=end_time,
        callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
    )
    rowgen = qry.gen_qry()
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    ais_data = data2frame(tracks)  # re-use previous function
# Display DataFrames
for key in ais_data.keys():
    print(ais_data[key])
A specific region can be queried for AIS data using aisdb.gis.Domain or one of its sub-classes to define a collection of shapely polygon features. For this example, the domain contains a single bounding box polygon derived from a longitude/latitude coordinate pair and radial distance specified in meters. If multiple features are included in the domain object, the domain boundaries will encompass the convex hull of all features.
# a circle with a 100km radius around the location point
domain = aisdb.DomainFromPoints(points=[(-69.34, 41.55)], radial_distances=[100000])
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
    qry = aisdb.DBQuery(
        dbconn=dbconn, start=start_time, end=end_time,
        xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
        ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
        callback=aisdb.database.sqlfcn_callbacks.in_validmmsi_bbox,
    )
    rowgen = qry.gen_qry()
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    ais_data = data2frame(tracks)  # re-use previous function
# Display DataFrames
for key in ais_data.keys():
    print(ais_data[key])
Additional query callbacks for filtering by region, timeframe, identifier, etc. can be found in aisdb.database.sql_query_strings and aisdb.database.sqlfcn_callbacks.
The above generator can be input into a processing function, yielding modified results. For example, to model the activity of vessels on a per-voyage or per-transit basis, each voyage is defined as a continuous vector of positions where the time between observed timestamps never exceeds 24 hours.
from datetime import timedelta
# Define a maximum time interval
maxdelta = timedelta(hours=24)
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
    qry = aisdb.DBQuery(
        dbconn=dbconn, start=start_time, end=end_time,
        xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
        ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
        callback=aisdb.database.sqlfcn_callbacks.in_validmmsi_bbox,
    )
    rowgen = qry.gen_qry()
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    # Split the generated tracks into segments
    track_segments = aisdb.split_timedelta(tracks, maxdelta)
    ais_data = data2frame(track_segments)  # re-use previous function
# Display DataFrames
for key in ais_data.keys():
    print(ais_data[key])
A common problem with AIS data is noise, where multiple vessels might broadcast using the same identifier (sometimes simultaneously). In such cases, AISdb can denoise the data:
(1) Denoising with Encoder: The aisdb.denoising_encoder.encode_greatcircledistance() function checks the approximate distance between each vessel’s position. It separates vectors where a vessel couldn’t reasonably travel using the most direct path, such as speeds over 50 knots.
(2) Distance and Speed Thresholds: A distance and speed threshold limits the maximum distance or time between messages that can be considered continuous.
(3) Scoring and Segment Concatenation: A score is computed for each position delta, with sequential messages nearby at shorter intervals given a higher score. This score is calculated by dividing the Haversine distance by elapsed time. Any deltas with a score not reaching the minimum threshold are considered the start of a new segment. New segments are compared to the end of existing segments with the same vessel identifier; if the score exceeds the minimum, they are concatenated. If multiple segments meet the minimum score, the new segment is concatenated to the existing segment with the highest score.
Notice that processing functions may be executed in sequence as a chain or pipeline, so after segmenting the individual voyages as shown above, results can be input into the encoder to remove noise and correct for vessels with duplicate identifiers.
distance_threshold = 20000 # the maximum allowed distance (meters) between consecutive AIS messages
speed_threshold = 50 # the maximum allowed vessel speed in consecutive AIS messages
minscore = 1e-6 # the minimum score threshold for track segment validation
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
    qry = aisdb.DBQuery(
        dbconn=dbconn, start=start_time, end=end_time,
        callback=aisdb.database.sqlfcn_callbacks.in_timerange,
    )
    rowgen = qry.gen_qry()
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    # Encode the track segments to clean and validate the track data
    tracks_encoded = aisdb.encode_greatcircledistance(
        tracks,
        distance_threshold=distance_threshold,
        speed_threshold=speed_threshold,
        minscore=minscore)
    ais_data = data2frame(tracks_encoded)  # re-use previous function
# Display DataFrames
for key in ais_data.keys():
    print(ais_data[key])
Building on the above processing pipeline, the resulting cleaned trajectories can be geofenced and filtered for results contained by at least one domain polygon and interpolated for uniformity.
# Define a domain with a central point and corresponding radial distances
domain = aisdb.DomainFromPoints(points=[(-69.34, 41.55),], radial_distances=[100000,])
# Filter the encoded tracks to include only those within the specified domain
tracks_filtered = aisdb.track_gen.zone_mask(tracks_encoded, domain)
# Interpolate the filtered tracks with a specified time interval
tracks_interp = aisdb.interp_time(tracks_filtered, step=timedelta(minutes=15))
Additional processing functions can be found in the aisdb.track_gen module.
The resulting processed voyage data can be exported in CSV format instead of being printed:
aisdb.write_csv(tracks_interp, 'ais_processed.csv')
AISdb supports integrating external data sources such as bathymetric charts and other raster grids.
To determine the approximate ocean depth at each vessel position, the aisdb.webdata.bathymetry module can be used.
import aisdb
# Set the data storage directory
data_dir = './testdata/'
# Download bathymetry grid from the internet
bathy = aisdb.webdata.bathymetry.Gebco(data_dir=data_dir)
bathy.fetch_bathymetry_grid()
Once the data has been downloaded, the Gebco() class may be used to append bathymetric data to tracks in the context of a TrackGen() processing pipeline, like the processing functions described above.
tracks = aisdb.TrackGen(qry.gen_qry(), decimate=False)
tracks_bathymetry = bathy.merge_tracks(tracks)  # merge tracks with bathymetry data
Also, see aisdb.webdata.shore_dist.ShoreDist for determining the approximate nearest distance to shore from vessel positions.
Similarly, arbitrary raster coordinate-gridded data may be appended to vessel tracks:
tracks = aisdb.TrackGen(qry.gen_qry())
raster_path = './GMT_intermediate_coast_distance_01d.tif'
# Load the raster file
raster = aisdb.webdata.load_raster.RasterFile(raster_path)
# Merge the generated tracks with the raster data
tracks = raster.merge_tracks(tracks, new_track_key="coast_distance")
AIS data from the database may be overlaid on a map, such as the one shown above, using the aisdb.web_interface.visualize() function. This function accepts a generator of track dictionaries such as those output by aisdb.track_gen.TrackGen().
from datetime import datetime, timedelta
import aisdb
from aisdb import DomainFromPoints
dbpath='example_data.db'
def color_tracks(tracks):
    ''' set the color of each vessel track using a color name or RGB value '''
    for track in tracks:
        track['color'] = 'blue'  # or an RGB string such as 'rgb(0,0,255)'
        yield track
# Set the start and end times for the query
start_time = datetime.strptime("2022-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
end_time = datetime.strptime("2022-01-31 00:00:00", '%Y-%m-%d %H:%M:%S')
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
    qry = aisdb.DBQuery(
        dbconn=dbconn,
        start=start_time,
        end=end_time,
        callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
    )
    rowgen = qry.gen_qry()
    # Convert queried rows to vessel trajectories
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    # Visualization
    aisdb.web_interface.visualize(
        tracks,
        visualearth=False,
        open_browser=True,
    )
For a complete plug-and-play solution, you may clone our Google Colab Notebook.
How to deploy your own Automatic Identification System (AIS) receiver.
In addition to utilizing AIS data provided by Spire for the Canadian coasts, you can install AIS receiver hardware to capture AIS data directly. The received data can be processed and stored in databases, which can then be used with AISdb. This approach offers additional data sources and allows users to collect and process their data (as illustrated in the pipeline below). Doing so allows you to customize your data collection efforts to meet specific needs and seamlessly integrate the data with AISdb for enhanced analysis and application. At the same time, you can share the data you collect with others.
Raspberry Pi or other computers with internet working capability
162MHz receiver, such as the Wegmatt dAISy 2 Channel Receiver
An antenna in the VHF frequency band (30MHz - 300MHz) e.g. Shakespeare QC-4 VHF Antenna
Optionally, you may want
Antenna mount
A filtered preamp, such as this one sold by Uputronics, to improve signal range and quality
An additional option is a free AIS receiver from MarineTraffic. This option may require you to share the data with the organization to help expand its AIS-receiving network.
When setting up your antenna, place it as high as possible and as far away from obstructions and other equipment as is practical.
Connect the antenna to the receiver. If using a preamp filter, connect it between the antenna and the receiver.
Connect the receiver to your Linux device via a USB cable. If using a preamp filter, power it with a USB cable.
Validate the hardware configuration
When connected via USB, the AIS receiver is typically found under /dev/ with a name beginning with ttyACM, for example /dev/ttyACM0. Ensure the device is listed in this directory.
To test the receiver, use the command sudo cat /dev/ttyACM0 to display its output. If all works as intended, you will see streams of bytes appearing on the screen.
$ sudo cat /dev/ttyACM0
!AIVDM,1,1,,A,B4eIh>@0<voAFw6HKAi7swf1lH@s,0*61
!AIVDM,1,1,,A,14eH4HwvP0sLsMFISQQ@09Vr2<0f,0*7B
!AIVDM,1,1,,A,14eGGT0301sM630IS2hUUavt2HAI,0*4A
!AIVDM,1,1,,B,14eGdb0001sM5sjIS3C5:qpt0L0G,0*0C
!AIVDM,1,1,,A,14eI3ihP14sM1PHIS0a<d?vt2L0R,0*4D
!AIVDM,1,1,,B,14eI@F@000sLtgjISe<W9S4p0D0f,0*24
!AIVDM,1,1,,B,B4eHt=@0:voCah6HRP1;?wg5oP06,0*7B
!AIVDM,1,1,,A,B4eHWD009>oAeDVHIfm87wh7kP06,0*20
A visual example of the antenna hardware setup that MERIDIAN has available is as follows:
Connect the receiver to the Raspberry Pi via a USB port, and then run the configure_rpi.sh script. This will install the Rust toolchain, AISdb dispatcher, and AISdb system service (described below), allowing the receiver to start at boot.
curl --proto '=https' --tlsv1.2 https://git-dev.cs.dal.ca/meridian/aisdb/-/raw/master/configure_rpi.sh | bash
Install Raspberry Pi OS with SSH enabled: Visit https://www.raspberrypi.com/software/ to download and install the Raspberry Pi OS. If using the RPi imager, please ensure you run it as an administrator.
Connect the receiver: Attach the receiver to the Raspberry Pi using a USB cable. Then log in to the Raspberry Pi and update the system with the following command: sudo apt-get update
Install the Rust toolchain: Install the Rust toolchain on the Raspberry Pi using the following command: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh Afterward, log out and log back in to add Rust and Cargo to the system path.
Install the network client and dispatcher: (a) from crates.io, using cargo install mproxy-client, or (b) from source, using the local path instead, e.g., cargo install --path ./dispatcher/client
Install systemd services: Set up new systemd services to run the AIS receiver and dispatcher. First, create a new text file ./ais_rcv.service with the contents in the block below, replacing User=ais and /home/ais with the username and home directory chosen in step 1.
[Unit]
Description="AISDB Receiver"
After=network-online.target
Documentation=https://aisdb.meridian.cs.dal.ca/doc/receiver.html
[Service]
Type=simple
User=ais
ExecStart=/home/ais/.cargo/bin/mproxy-client --path /dev/ttyACM0 --server-addr 'aisdb.meridian.cs.dal.ca:9921'
Restart=always
RestartSec=30
[Install]
WantedBy=default.target
This service will broadcast receiver input downstream to aisdb.meridian.cs.dal.ca via UDP. You can add additional endpoints at this stage; for more information, see mproxy-client --help. Additional AIS networking tools, such as mproxy-forward, mproxy-server, and mproxy-reverse, are available in the ./dispatcher source directory.
Next, link and enable the service on the Raspberry Pi to ensure the receiver starts at boot:
sudo systemctl enable systemd-networkd-wait-online.service
sudo systemctl link ./ais_rcv.service
sudo systemctl daemon-reload
sudo systemctl enable ais_rcv
sudo systemctl start ais_rcv
See more examples in docker-compose.yml.
On some Raspberry Pi hardware (such as the author's Raspberry Pi 4 Model B Rev 1.5), the Linux device file representing the serial interface of a connected dAISy AIS receiver is not always /dev/ttyACM0, as assumed in our ./ais_rcv.service.
You can check the actual device file in use by running:
ls -l /dev
For example, the author found that serial0 was linked to ttyS0.
Simply changing /dev/ttyACM0 to /dev/ttyS0 may result in receiving garbled AIS signals. This is because the default baud rate settings are different. You can modify the default baud rate for ttyS0 using the following command:
stty -F /dev/ttyS0 38400 cs8 -cstopb -parenb
Data querying with AISdb involves setting up a connection to the database, defining query parameters, creating and executing the query, and processing the results. Following the previous tutorial, Database Loading, we set up a database connection and made simple queries and visualizations. This tutorial digs into the data query functions and parameters and shows the kinds of queries you can make with AISdb.
Data querying with AISdb includes two components: DBQuery and TrackGen. In this section, we will introduce each component with examples. Before starting data querying, please ensure you have connected to the database. If you have not done so, please follow the instructions and examples in Database Loading or Quick Start.
The DBQuery class is used to create a query object that specifies the parameters for data retrieval, including the time range, spatial domain, and any filtering callbacks. Here is an example to create a DBQuery object and use parameters to specify the time range and geographical locations:
from aisdb.database.dbqry import DBQuery
# Specify database path
dbpath = ...
# Specify constraints (optional)
start_time = ...
end_time = ...
domain = ...
# Create a query object to fetch data within time and geographical range
qry = DBQuery(
dbconn=dbconn, # Database connection object
start=start_time, # Start time for the query
end=end_time, # End time for the query
xmin=domain.boundary['xmin'], # Minimum longitude of the domain
xmax=domain.boundary['xmax'], # Maximum longitude of the domain
ymin=domain.boundary['ymin'], # Minimum latitude of the domain
ymax=domain.boundary['ymax'], # Maximum latitude of the domain
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi # Callback function to filter data
)
Callback functions are used in the DBQuery class to filter data based on specific criteria. Some common callbacks include: in_bbox, in_time_bbox, valid_mmsi, and in_time_bbox_validmmsi. These callbacks ensure that the data retrieved matches the specific criteria defined in the query. Please find examples of using different callbacks with other parameters in Query types with practical examples.
For more callback functions, refer to the API documentation here: API-Doc
gen_qry
The function gen_qry is a method of the DBQuery class in AISdb. It is responsible for generating rows of data that match the query criteria specified when creating the DBQuery object. It acts as a generator, yielding one row at a time and handling large datasets efficiently.
Two callback functions can be passed to gen_qry. They are:
crawl_dynamic: iterates only over the position reports table; this is the default.
crawl_dynamic_static: iterates over both the position reports and static messages tables.
After creating the DBQuery object, we can generate rows with gen_qry:
# Generate rows from the query
from aisdb.database import sqlfcn  # needed for the crawl_dynamic_static callback
rowgen = qry.gen_qry(fcn=sqlfcn.crawl_dynamic_static)  # the fcn parameter is optional
# Process the generated rows as needed
for row in rowgen:
    print(row)
Each row from gen_qry is a tuple or dictionary representing a record in the database.
The TrackGen class converts the generated rows from gen_qry into tracks (trajectories). It takes the row generator and, optionally, a decimate parameter to control point reduction. This conversion is essential for analyzing vessel movements, identifying patterns, and visualizing trajectories in later steps.
Following the generated rows above, here is how to use the TrackGen class:
from aisdb.track_gen import TrackGen
# Convert the generated rows into tracks
tracks = TrackGen(rowgen, decimate=False)
The TrackGen class yields "tracks", a generator object. While iterating over tracks, each element is a dictionary representing the track of a specific vessel:
for track in tracks:
    mmsi = track['mmsi']
    lons = track['lon']
    lats = track['lat']
    speeds = track['sog']
    print(f"Track for vessel MMSI {mmsi}:")
    for lon, lat, speed in zip(lons[:3], lats[:3], speeds[:3]):
        print(f" - Lon: {lon}, Lat: {lat}, Speed: {speed}")
    break  # Exit after the first track
This is the output with our sample data:
Track for vessel MMSI 316004240:
- Lon: -63.54868698120117, Lat: 44.61691665649414, Speed: 7.199999809265137
- Lon: -63.54880905151367, Lat: 44.61708450317383, Speed: 7.099999904632568
- Lon: -63.55659866333008, Lat: 44.626953125, Speed: 1.5
In this section, we will provide practical examples of the most common query types you can make using the DBQuery class, including querying within a time range, within geographical areas, and tracking vessels by MMSI. Different queries can be achieved by changing the callback and other parameters defined in the DBQuery class. Then, we will use TrackGen to convert these query results into structured tracks for further analysis and visualization.
First, we need to import the necessary packages and prepare data:
import os
import aisdb
from datetime import datetime, timedelta
from aisdb import DBConn, DBQuery, DomainFromPoints
dbpath='YOUR_DATABASE.db' # Define the path to your database
Querying data within a specified time range can be done using the in_timerange_validmmsi callback in the DBQuery class:
start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
end_time = datetime.strptime("2018-01-02 00:00:00", '%Y-%m-%d %H:%M:%S')
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
    qry = aisdb.DBQuery(
        dbconn=dbconn, start=start_time, end=end_time,
        callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
    )
    rowgen = qry.gen_qry()
    # Convert queried rows to vessel trajectories
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    # Visualization
    aisdb.web_interface.visualize(
        tracks,
        visualearth=True,
        open_browser=True,
    )
This will display the queried vessel tracks (within the time range and with a valid MMSI) on the map:
You may find noise in some of the track data. In Data Cleaning, we introduced the de-noising methods in AISdb that can effectively remove unreasonable or error data points, ensuring more accurate and reliable vessel trajectories.
In practical scenarios, people may have specific points/areas of interest. DBQuery includes parameters to define a bounding box and has relevant callbacks. Let's look at an example:
domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000]) # a circle with a 50 km radius around the location point
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
    qry = aisdb.DBQuery(
        dbconn=dbconn, start=start_time, end=end_time,
        xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
        ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
        callback=aisdb.database.sqlfcn_callbacks.in_validmmsi_bbox,
    )
    rowgen = qry.gen_qry()
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    aisdb.web_interface.visualize(
        tracks,
        domain=domain,
        visualearth=True,
        open_browser=True,
    )
This will show all the vessel tracks with valid MMSI in the defined bounding box:
In the above examples, we queried data within a time range and within a geographical area. If you want to combine multiple query criteria, check out the available callback types in the API Docs. In the last example above, we can simply change the callback to obtain vessel tracks within both the time range and the geographical area:
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsiThe displayed vessel tracks:
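For reference, here is the bounding-box query from the previous example with the combined callback swapped in (a sketch that reuses start_time, end_time, domain, and dbpath as defined above):
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
    qry = aisdb.DBQuery(
        dbconn=dbconn, start=start_time, end=end_time,
        xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
        ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
        # Combined filter: time range, bounding box, and valid MMSIs
        callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
    )
    rowgen = qry.gen_qry()
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)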
In addition to time and location ranges, you can track one or more vessels of interest by specifying their MMSIs in the query. Here is an example of tracking several vessels within a time range:
import random
def assign_colors(mmsi_list):
colors = {}
for mmsi in mmsi_list:
colors[mmsi] = "#{:06x}".format(random.randint(0, 0xFFFFFF)) # Random color in hex
return colors
# Create a function to color tracks
def color_tracks(tracks, colors):
colored_tracks = []
for track in tracks:
mmsi = track['mmsi']
color = colors.get(mmsi, "#000000") # Default to black if no color assigned
track['color'] = color
colored_tracks.append(track)
return colored_tracks
# Set the start and end times for the query
start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
end_time = datetime.strptime("2018-12-31 00:00:00", '%Y-%m-%d %H:%M:%S')
# Create a list of vessel MMSIs you want to track
MMSI = [636017611,636018124,636018253]
# Assign colors to each MMSI
colors = assign_colors(MMSI)
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time, mmsis = MMSI,
callback=aisdb.database.sqlfcn_callbacks.in_timerange_inmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
colored_tracks = color_tracks(tracks, colors)
# Visualizing the tracks
aisdb.web_interface.visualize(
colored_tracks,
visualearth=True,
open_browser=True,
)
This tutorial will guide you in using the AISdb package to load AIS data into a database and perform queries. We will begin with AISdb installation and environment setup, then proceed to examples of querying the loaded data and creating simple visualizations.
Preparing a Python virtual environment for AISdb is a safe practice. It allows you to manage dependencies and prevent conflicts with other projects, ensuring a clean and isolated setup for your work with AISdb. Run these commands in your terminal based on the operating system you are using:
# Linux / macOS
python -m venv AISdb # create a python virtual environment
source ./AISdb/bin/activate # activate the virtual environment
pip install aisdb # install AISdb from https://pypi.org/project/aisdb/

# Windows
python -m venv AISdb # create a virtual environment
./AISdb/Scripts/activate # activate the virtual environment
pip install aisdb # install the AISdb package using pip
Now you can check your installation by running:
$ python
>>> import aisdb
>>> aisdb.__version__ # should return '1.7.0' or newerIf you're using AISdb in Jupyter Notebook, please include the following commands in your notebook cells:
# install nest-asyncio for enabling asyncio.run() in Jupyter Notebook
%pip install nest-asyncio
# Some of the systems may show the following error when running the user interface:
# urllib3 v2.0 only supports OpenSSL 1.1.1+; currently, the 'SSL' module is compiled with 'LibreSSL 2.8.3'.
# install urllib3 v1.26.6 to avoid this error
%pip install urllib3==1.26.6Then, import the required packages:
from datetime import datetime, timedelta
import os
import aisdb
import nest_asyncio
nest_asyncio.apply()This section will show you how to efficiently load AIS data into a database.
AISdb includes two database connection approaches:
SQLite database connection; and,
PostgreSQL database connection.
We work with the SQLite database in most usage scenarios. Here is an example of loading the sample data included in the AISdb package:
# List the test data files included in the package
print(os.listdir(os.path.join(aisdb.sqlpath, '..', 'tests', 'testdata')))
# You will see the print result:
# ['test_data_20210701.csv', 'test_data_20211101.nm4', 'test_data_20211101.nm4.gz']
# Set the path for the SQLite database file to be used
dbpath = './test_database.db'
# Use test_data_20210701.csv as the test data
filepaths = [os.path.join(aisdb.sqlpath, '..', 'tests', 'testdata', 'test_data_20210701.csv')]
with aisdb.DBConn(dbpath = dbpath) as dbconn:
aisdb.decode_msgs(filepaths=filepaths, dbconn=dbconn, source='TESTING')The code above decodes the AIS messages from the CSV file specified in filepaths and inserts them into the SQLite database connected via dbconn.
Following is a quick example of a query and visualization of the data we just loaded with AISdb:
start_time = datetime.strptime("2021-07-01 00:00:00", '%Y-%m-%d %H:%M:%S')
end_time = datetime.strptime("2021-07-02 00:00:00", '%Y-%m-%d %H:%M:%S')
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn,
callback=aisdb.database.sql_query_strings.in_timerange,
start=start_time,
end=end_time,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
if __name__ == '__main__':
aisdb.web_interface.visualize(
tracks,
visualearth=True,
open_browser=True,
)In addition to the SQLite connection, AISdb supports PostgreSQL for its superior concurrency handling and data-sharing capabilities, making it suitable for collaborative environments and for handling larger datasets efficiently. The structure and interactions with PostgreSQL are designed to provide robust and scalable solutions for AIS data storage and querying. To use PostgreSQL, you need the psycopg2 library:
pip install psycopg2To connect to a PostgreSQL database, AISdb uses the PostgresDBConn class:
from aisdb.database.dbconn import PostgresDBConn
# Option 1: Using keyword arguments
dbconn = PostgresDBConn(
hostaddr='127.0.0.1', # Replace with the PostgreSQL address
port=5432, # Replace with the PostgreSQL running port
user='USERNAME', # Replace with the PostgreSQL username
password='PASSWORD', # Replace with your password
dbname='aisviz' # Replace with your database name
)
# Option 2: Using a connection string
dbconn = PostgresDBConn('postgresql://USERNAME:PASSWORD@HOST:PORT/DATABASE')After establishing a connection to the PostgreSQL database, specifying the paths of the data files, and calling the aisdb.decode_msgs function, the following operations are performed in order: data file processing, table creation, data insertion, and index rebuilding.
Please pay close attention to the flags in aisdb.decode_msgs, as recent updates provide more flexibility for database configurations. These updates include support for ingesting NOAA data into the aisdb format and the option to structure tables using either the original B-Tree indexes or TimescaleDB's structure when the extension is enabled. In particular, pay attention to the following parameters:
source (str, optional)
Specifies the data source to be processed and loaded into the database.
Options: "Spire", "NOAA"/"noaa", or leave empty.
Default: empty; processing proceeds as with the Spire source.
raw_insertion (bool, optional)
If False, the function will drop and rebuild indexes to speed up data loading.
Default: True.
timescaledb (bool, optional)
Set to True only if using the TimescaleDB extension in your PostgreSQL database.
Refer to the TimescaleDB documentation for proper setup and usage.
The following example demonstrates how to process and load Spire data for the entire year 2024 into an aisdb database with the TimescaleDB extension installed:
import time
import aisdb

# psql_conn_string should hold your PostgreSQL connection string (see the PostgresDBConn examples above)
start_year = 2024
end_year = 2024
start_month = 1
end_month = 12
overall_start_time = time.time()
for year in range(start_year, end_year + 1):
for month in range(start_month, end_month + 1):
print(f'Loading {year}{month:02d}')
month_start_time = time.time()
filepaths = aisdb.glob_files(f'/slow-array/Spire/{year}{month:02d}/','.zip')
filepaths = sorted([f for f in filepaths if f'{year}{month:02d}' in f])
print(f'Number of files: {len(filepaths)}')
with aisdb.PostgresDBConn(libpq_connstring=psql_conn_string) as dbconn:
try:
aisdb.decode_msgs(filepaths,
dbconn=dbconn,
source='Spire',
verbose=True,
skip_checksum=True,
raw_insertion=True,
workers=6,
timescaledb=True,
)
except Exception as e:
print(f'Error loading {year}{month:02d}: {e}')
continueExample of performing queries and visualizations with PostgreSQL database:
from aisdb.gis import DomainFromPoints
from aisdb.database.dbqry import DBQuery
from datetime import datetime
# Define a spatial domain centered around the point (-63.6, 44.6) with a radial distance of 50000 meters.
domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000])
# Create a query object to fetch AIS data within the specified time range and spatial domain.
qry = DBQuery(
dbconn=dbconn,
start=datetime(2023, 1, 1), end=datetime(2023, 2, 1),
xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi
)
# Generate rows from the query
rowgen = qry.gen_qry()
# Convert the generated rows into tracks
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
# Visualize the tracks on a map
aisdb.web_interface.visualize(
tracks, # The tracks (trajectories) to visualize.
domain=domain, # The spatial domain to use for the visualization.
visualearth=True, # If True, use Visual Earth for the map background.
open_browser=True # If True, automatically open the visualization in a web browser.
)Moreover, if you wish to use your own AIS data to create and process a database with AISdb, please check out our instructional guide on data processing and database creation: Using Your AIS Data.
A little bit about where we stand.
Welcome to AISdb - a comprehensive gateway for Automatic Identification System (AIS) data uses and applications. AISdb is part of the Making Vessels Tracking Data Available to Everyone (AISViz) project within the Marine Environmental Research Infrastructure for Data Integration and Application Network (MERIDIAN) initiative at Dalhousie University, designed to streamline the collection, processing, and analysis of AIS data, both in live-streaming scenarios and through historical records.
The primary features AISdb provides include:
SQL database for storing AIS position reports and vessel metadata: At the heart of AISdb is a database built on SQLite, giving users a friendly Python interface with which to interact. This interface simplifies tasks like database creation, data querying, processing, visualization, and even exporting data to CSV format for diverse uses. To cater to advanced needs, AISdb supports using Postgres, offering superior concurrency handling and data-sharing capabilities for collaborative environments.
Vessel data cleaning and trajectory modeling: AISdb includes vessel position cleaning and trajectory modeling features. This ensures that the data used for analyses is accurate and reliable, providing a solid foundation for further studies and applications.
Integration with environmental context and external metadata: One of AISdb's unique features is its ability to enrich AIS datasets with environmental context. Users can seamlessly integrate oceanographic and bathymetric data in raster formats to bring depth to their analyses — quite literally, as the tool allows for incorporating seafloor depth data underneath vessel positions. Such versatility ensures that AISdb users can merge various environmental data points with AIS information, resulting in richer, multi-faceted maritime studies.
Advanced features for maritime studies: AISdb offers network graph analysis, MMSI deduplication, interpolation, and other processing utilities. These features enable advanced data processing and analysis, supporting complex maritime studies and applications.
Python interface and machine learning for vessel behavior modeling: AISdb includes a Python interface backed by a Rust core, which paves the way for incorporating machine learning and deep learning techniques into vessel behavior modeling in an optimized way. This aspect of AISdb enhances the reproducibility and scalability of research, be it for academic exploration or practical industry applications.
Research support: AISdb is more than just a storage and processing tool; it is a comprehensive platform designed to support research. Through a formal partnership with our research initiative (contact us for more information), academics, industry experts, and researchers can access extensive Canadian AIS data up to 100 km from the Canadian coastline. This dataset spans from January 2012 to the present and is updated monthly. AISdb offers raw and parsed data formats, eliminating preprocessing needs and streamlining AIS-related research.
The AISViz team is based in the Modeling and Analytics on Predictive Systems (MAPS) lab, in collaboration with the Maritime Risk and Safety (MARS) research group at Dalhousie University. Funded by the Department of Fisheries and Oceans Canada (DFO), our mission revolves around democratizing AIS data use, making it accessible and understandable across multiple sectors, from government and academia to NGOs and the broader public. In addition, AISViz aims to introduce machine learning applications into AISdb's AIS data handling, streamlining user interactions with AIS data and enhancing the user experience by simplifying data access.
Our commitment goes beyond just providing tools. Through AISViz, we're opening doors to innovative research and policy development, targeting environmental conservation, maritime traffic management, and much more. Whether you're a professional in the field, an educator, or a maritime enthusiast, AISViz and its components, including AISdb, offer the knowledge and technology to deepen your understanding and significantly impact marine vessel tracking and the well-being of our oceans.
Ruixin Song is a research assistant in the Computer Science Department at Dalhousie University. She has an M.Sc. in Computer Science and a B.Eng. in Spatial Information and Digital Technology. Her recent work focuses on marine traffic data analysis and physics-inspired models, particularly in relation to biological invasions in the ocean. Her research interests include mobility data mining, graph neural networks, and network flow and optimization problems.
Contact: rsong@dal.ca
Gabriel Spadon is an Assistant Professor at the Faculty of Computer Science at Dalhousie University, Halifax - NS, Canada. He holds a Ph.D. and an MSc in Computer Science from the University of Sao Paulo, Sao Carlos - SP, Brazil. His research focuses on spatio-temporal analytics, time-series forecasting, and complex network mining, with a deep involvement in data science and engineering, as well as geoinformatics.
Contact: spadon@dal.ca
Ron Pelot has a Ph.D. in Management Sciences and is a Professor of Industrial Engineering at Dalhousie University. For the last 30 years, he and his team have been working on developing new software tools and analysis methods for maritime traffic safety, coastal zone security, and marine spills. Their research methods include spatial risk analysis, vessel traffic modeling, data processing, pattern analysis, location models for response resource allocation, safety analyses, and cumulative shipping impact studies.
Contact: ronald.pelot@dal.ca
Adjunct Members
Vaishnav Vaidheeswaran is a Master's student in Computer Science at Dalhousie University. He holds a B.Tech in Computer Science and Engineering and has three years of experience as a software engineer in India, working at cutting-edge startups. His ongoing work addresses incorporating spatial knowledge into trajectory forecasting models to reduce aleatoric uncertainty coming from stochastic interactions of the vessel with the environment. His research interests include large language models, graph neural networks, and reinforcement learning.
Contact: vaishnav@dal.ca
Jinkun Chen is a Ph.D. student in Computer Science at Dalhousie University, specializing in Explainable AI, Natural Language Processing (NLP), and Visualization. He earned a bachelor's degree in Computer Science with First-Class Honours from Dalhousie University. Jinkun is actively involved in research, working on advancing fairness, responsibility, trustworthiness, and explainability within Large Language Models (LLMs) and AI.
Jay Kumar has a Ph.D. in Computer Science and Technology and was a postdoctoral fellow at the Department of Industrial Engineering at Dalhousie University. He has researched AI models for time-series data for over five years, focusing on Recurrent Neural models, probabilistic modeling, and feature engineering data analytics applied to ocean traffic. His research interests include Spatio-temporal Data Mining, Stochastic Modeling, Machine Learning, and Deep Learning.
Matthew Smith has a BSc degree in Applied Computer Science from Dalhousie University and specializes in managing and analyzing vessel tracking data. He is currently a Software Engineer at Radformation in Toronto, ON. Matt served as the AIS data manager on the MERIDIAN project, where he supported research groups across Canada in accessing and utilizing AIS data. The data was used to answer a range of scientific queries, including the impact of shipping on underwater noise pollution and the danger posed to endangered marine mammals by vessel collisions.
Casey Hilliard has a BSc degree in Computer Science from Dalhousie University and was a Senior Data Manager at the Institute for Big Data Analytics. He is currently a Chief Architect at GSTS (Global Spatial Technology Solutions) in Dartmouth, NS. Casey was a long-time research support staff member at the Institute and an expert in managing and using AIS vessel-tracking data. During his time, he assisted in advancing the Institute's research projects by managing and organizing large datasets, ensuring data integrity, and facilitating data usage in research.
Stan Matwin was the director of the Institute for Big Data Analytics, Dalhousie University, Halifax, Nova Scotia; he is a professor and Canada Research Chair (Tier 1) in Interpretability for Machine Learning. He is also a distinguished professor (Emeritus) at the University of Ottawa and a full professor with the Institute of Computer Science, Polish Academy of Sciences. His main research interests include big data, text mining, machine learning, and data privacy. He is a member of the Editorial Boards of IEEE Transactions on Knowledge and Data Engineering and the Journal of Intelligent Information Systems. He received the Lifetime Achievement Award of the Canadian AI Association (CAIAC).
We are passionate about fostering a collaborative and engaged community. We welcome your questions, insights, and feedback as vital components of our continuous improvement and innovation. Should you have any inquiries about AISdb, desire further information on our research, or wish to explore potential collaborations, please don't hesitate to contact us. Staying connected with users and researchers plays a crucial role in shaping the tool's development and ensuring it meets the diverse needs of our growing user base. You can easily contact our team via email or our GitHub team platform. In addition to addressing individual queries, we are committed to organizing webinars and workshops and presenting at conferences to share knowledge, gather feedback, and widen our outreach (stay tuned for more information about these). Together, let's advance the understanding and utilization of marine data for a brighter, more informed future in ocean research and preservation.
AISdb includes a function called aisdb.gis.delta_meters that calculates the Haversine distance in meters between consecutive positions within a vessel track. This function is essential for analyzing vessel movement patterns and ensuring accurate distance calculations on the Earth's curved surface. It is also integrated into the denoising encoder, which compares distances against a threshold to aid in the data-cleaning process.
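For intuition, the Haversine formula itself can be sketched in a few lines of NumPy. This is an illustrative implementation only, not the internal code of aisdb.gis.delta_meters:
import numpy as np

def haversine_meters(lat1, lon1, lat2, lon2, r=6371000.0):
    # Great-circle distance in meters between two points given in decimal degrees
    phi1, phi2 = np.radians(lat1), np.radians(lat2)
    dphi = np.radians(lat2 - lat1)
    dlmb = np.radians(lon2 - lon1)
    a = np.sin(dphi / 2) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(dlmb / 2) ** 2
    return 2 * r * np.arcsin(np.sqrt(a))

# Distance between the first two sample positions used in the example below (~6.96 km)
print(haversine_meters(44.5704, -63.5293, 44.5130, -63.4941))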
Here is an example of calculating the Haversine distance between each pair of consecutive points on a track:
import aisdb
import numpy as np
from aisdb.gis import dt_2_epoch
from datetime import datetime
y1, x1 = 44.57039426840729, -63.52931373766157
y2, x2 = 44.51304767533133, -63.494075674952555
y3, x3 = 44.458038982492134, -63.535634138077945
y4, x4 = 44.393941339104074, -63.53826396955358
y5, x5 = 44.14245580737021, -64.16608964280064
t1 = dt_2_epoch( datetime(2021, 1, 1, 1) )
t2 = dt_2_epoch( datetime(2021, 1, 1, 2) )
t3 = dt_2_epoch( datetime(2021, 1, 1, 3) )
t4 = dt_2_epoch( datetime(2021, 1, 1, 4) )
t5 = dt_2_epoch( datetime(2021, 1, 1, 7) )
# Create a sample track
tracks_short = [
dict(
lon=np.array([x1, x2, x3, x4, x5]),
lat=np.array([y1, y2, y3, y4, y5]),
time=np.array([t1, t2, t3, t4, t5]),
mmsi=123456789,
dynamic=set(['lon', 'lat', 'time']),
static=set(['mmsi'])
)
]
# Calculate the Haversine distance
for track in tracks_short:
print(aisdb.gis.delta_meters(track))[ 6961.401286 6948.59446128 7130.40147082 57279.94580704]If we visualize this track on the map, we can observe:
Extracting distance features from and to points-of-interest using raster files.
The distances of a vessel from the nearest shore, coast, and port are essential to perform particular tasks such as vessel behavior analysis, environmental monitoring, and maritime safety assessments. AISdb offers functions to acquire these distances for specific vessel positions. In this tutorial, we provide examples of calculating the distance in kilometers from shore and from the nearest port for a given point.
First, we create a sample track:
import aisdb
import numpy as np
from datetime import datetime
from aisdb.gis import dt_2_epoch
y1, x1 = 44.57039426840729, -63.52931373766157
y2, x2 = 44.51304767533133, -63.494075674952555
y3, x3 = 44.458038982492134, -63.535634138077945
y4, x4 = 44.393941339104074, -63.53826396955358
y5, x5 = 44.14245580737021, -64.16608964280064
t1 = dt_2_epoch( datetime(2021, 1, 1, 1) )
t2 = dt_2_epoch( datetime(2021, 1, 1, 2) )
t3 = dt_2_epoch( datetime(2021, 1, 1, 3) )
t4 = dt_2_epoch( datetime(2021, 1, 1, 4) )
t5 = dt_2_epoch( datetime(2021, 1, 1, 7) )
# creating a sample track
tracks_short = [
dict(
mmsi=123456789,
lon=np.array([x1, x2, x3, x4, x5]),
lat=np.array([y1, y2, y3, y4, y5]),
time=np.array([t1, t2, t3, t4, t5]),
dynamic=set(['lon', 'lat', 'time']),
static=set(['mmsi'])
)
]Here is what the sample track looks like:
The aisdb.webdata.shore_dist.ShoreDist class calculates the nearest distance to shore using a raster file containing shore-distance data. Currently, calling the get_distance function of ShoreDist automatically downloads the shore-distance raster file from our server. The function then merges the shore-distance values into the tracks in the provided track list, creating a new key, "km_from_shore", that stores the distance for each position.
from aisdb.webdata.shore_dist import ShoreDist
with ShoreDist(data_dir="./testdata/") as sdist:
# Getting distance from shore for each point in the track
for track in sdist.get_distance(tracks_short):
assert 'km_from_shore' in track['dynamic']
assert 'km_from_shore' in track.keys()
print(track['km_from_shore'])[ 1 3 2 9 14]Similar to acquiring the distance from shore, CoastDist is implemented to obtain the distance between the given track positions and the coastline.
from aisdb.webdata.shore_dist import CoastDist
with CoastDist(data_dir="./testdata/") as cdist:
# Getting distance from the coast for each point in the track
for track in cdist.get_distance(tracks_short):
assert 'km_from_coast' in track['dynamic']
assert 'km_from_coast' in track.keys()
print(track['km_from_coast'])[ 1 3 2 8 13]Like the distances from the coast and shore, the aisdb.webdata.shore_dist.PortDist class determines the distance between the track positions and the nearest ports.
from aisdb.webdata.shore_dist import PortDist
with PortDist(data_dir="./testdata/") as pdist:
# Getting distance from the port for each point in the track
for track in pdist.get_distance(tracks_short):
assert 'km_from_port' in track['dynamic']
assert 'km_from_port' in track.keys()
print(track['km_from_port'])[ 4.72144175 7.47747231 4.60478449 11.5642271 28.62511253]This tutorial introduces visualization options for vessel trajectories processed using AISdb, including AISdb's integrated web interface and alternative approaches with popular Python visualization packages. Practical examples are provided for each tool, illustrating how to process and visualize AISdb tracks effectively.
AISdb provides an integrated data visualization feature through the aisdb.web_interface.visualize module, which allows users to generate interactive maps displaying vessel tracks. This built-in tool is designed for simplicity and ease of use, offering customizable visualizations directly from AIS data without requiring extensive setup.
Here is an example of using the web interface module to show queried data with colors. To display vessel tracks in a single color:
import aisdb
from datetime import datetime
from aisdb.database.dbconn import SQLiteDBConn
from aisdb import DBConn, DBQuery, DomainFromPoints
import nest_asyncio
nest_asyncio.apply()
dbpath='YOUR_DATABASE.db' # Define the path to your database
# Set the start and end times for the query
start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
end_time = datetime.strptime("2018-01-03 00:00:00", '%Y-%m-%d %H:%M:%S')
# Define a circle with a 100km radius around the location point
domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[100000])
def color_tracks(tracks):
""" Set the color of each vessel track using a color name or RGB value. """
for track in tracks:
track['color'] = 'yellow'
yield track
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time,
xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
colored_tracks = color_tracks(tracks)
# Visualization
aisdb.web_interface.visualize(
colored_tracks,
domain=domain,
visualearth=True,
open_browser=True,
)If you want to visualize vessel tracks in different colors based on MMSI, here's an example that demonstrates how to color-code tracks for easy identification:
import random
def color_tracks2(tracks):
colors = {}
for track in tracks:
mmsi = track.get('mmsi')
if mmsi not in colors:
# Assign a random color to this MMSI if not already assigned
colors[mmsi] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
track['color'] = colors[mmsi] # Set the color for the current track
yield track
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time,
xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
colored_tracks = list(color_tracks2(tracks))
# Visualization
aisdb.web_interface.visualize(
colored_tracks,
domain=domain,
visualearth=True,
open_browser=True,
)Users seeking more advanced or specialized visualization capabilities can leverage several alternative Python packages. For instance, Contextily, Basemap, and Cartopy are excellent for creating detailed 2D plots, while Plotly offers powerful interactive graphs. Additionally, Kepler.gl caters to users needing dynamic, large-scale visualizations or 3D mapping. These alternatives allow for a deeper exploration of AIS data, offering flexibility in how data is presented and analyzed beyond the default capabilities of AISdb.
import aisdb
from datetime import datetime
from aisdb.database.dbconn import SQLiteDBConn
from aisdb import DBConn, DBQuery, DomainFromPoints
import contextily as cx
import matplotlib.pyplot as plt
import random
import nest_asyncio
nest_asyncio.apply()
dbpath='YOUR_DATABASE.db' # Define the path to your database
# Set the start and end times for the query
start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
end_time = datetime.strptime("2018-01-03 00:00:00", '%Y-%m-%d %H:%M:%S')
# Define a circle with a 100km radius around the location point
domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[100000])
def color_tracks2(tracks):
colors = {}
for track in tracks:
mmsi = track.get('mmsi')
if mmsi not in colors:
# Assign a random color to this MMSI if not already assigned
colors[mmsi] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
track['color'] = colors[mmsi] # Set the color for the current track
yield track
def plot_tracks_with_contextily(tracks):
plt.figure(figsize=(12, 8))
for track in tracks:
plt.plot(track['lon'], track['lat'], color=track['color'], linewidth=2)
# Add basemap
cx.add_basemap(plt.gca(), crs='EPSG:4326', source=cx.providers.CartoDB.Positron)
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Vessel Tracks with Basemap')
plt.show()
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time,
xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
colored_tracks = list(color_tracks2(tracks))
plot_tracks_with_contextily(colored_tracks)Note: mpl_toolkits.basemap requires NumPy v1; downgrade NumPy to v1.26.4 to use Basemap, or use one of the other alternatives mentioned, such as Contextily.
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
def plot_tracks_with_basemap(tracks):
plt.figure(figsize=(12, 8))
# Define the geofence boundaries
llcrnrlat = 42.854329883666175 # Latitude of the southwest corner
urcrnrlat = 47.13666808816243 # Latitude of the northeast corner
llcrnrlon = -68.73998377599209 # Longitude of the southwest corner
urcrnrlon = -56.92378296577808 # Longitude of the northeast corner
# Create the Basemap object with the geofence
m = Basemap(projection='merc',
llcrnrlat=llcrnrlat, urcrnrlat=urcrnrlat,
llcrnrlon=llcrnrlon, urcrnrlon=urcrnrlon, resolution='i')
m.drawcoastlines()
m.drawcountries()
m.drawmapboundary(fill_color='aqua')
m.fillcontinents(color='lightgreen', lake_color='aqua')
for track in tracks:
lons, lats = track['lon'], track['lat']
x, y = m(lons, lats)
m.plot(x, y, color=track['color'], linewidth=2)
plt.show()
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time,
xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
colored_tracks = list(color_tracks2(tracks))
plot_tracks_with_basemap(colored_tracks)
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
def plot_tracks_with_cartopy(tracks):
plt.figure(figsize=(12, 8))
ax = plt.axes(projection=ccrs.Mercator())
ax.coastlines()
for track in tracks:
lons, lats = track['lon'], track['lat']
ax.plot(lons, lats, transform=ccrs.PlateCarree(), color=track['color'], linewidth=2)
plt.title('AIS Tracks Visualization with Cartopy')
plt.show()
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time,
xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
colored_tracks = list(color_tracks2(tracks))
plot_tracks_with_cartopy(colored_tracks)
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
def track2dataframe(tracks):
data = []
# Iterate over each track in the vessels_generator
for track in tracks:
# Unpack static information
mmsi = track['mmsi']
rot = track['rot']
maneuver = track['maneuver']
heading = track['heading']
# Unpack dynamic information
times = track['time']
lons = track['lon']
lats = track['lat']
cogs = track['cog']
sogs = track['sog']
utc_seconds = track['utc_second']
# Iterate over the dynamic arrays and create a row for each time point
for i in range(len(times)):
data.append({
'mmsi': mmsi,
'rot': rot,
'maneuver': maneuver,
'heading': heading,
'time': times[i],
'longitude': lons[i],
'latitude': lats[i],
'cog': cogs[i],
'sog': sogs[i],
'utc_second': utc_seconds[i],
})
# Convert the list of dictionaries to a pandas DataFrame
df = pd.DataFrame(data)
return df
def plotly_visualize(data, visual_type='lines'):
if (visual_type=='scatter'):
# Create a scatter plot for the vessel data points using scatter_geo
fig = px.scatter_geo(
data,
lat="latitude",
lon="longitude",
color="mmsi", # Color by vessel identifier
hover_name="mmsi",
hover_data={"time": True},
title="Vessel Data Points"
)
else:
# Create a line plot for the vessel trajectory using scatter_geo
fig = px.line_geo(
data,
lat="latitude",
lon="longitude",
color="mmsi", # Color by vessel identifier
hover_name="mmsi",
hover_data={"time": True},
)
# Set the map style and projection
fig.update_geos(
projection_type="azimuthal equal area", # Change this to 'natural earth', 'azimuthal equal area', etc.
showland=True,
landcolor="rgb(243, 243, 243)",
countrycolor="rgb(204, 204, 204)",
lonaxis=dict(range=[-68.73998377599209, -56.92378296577808]), # Longitude range (geofence)
lataxis=dict(range=[42.854329883666175, 47.13666808816243]) # Latitude range (geofence)
)
# Set the layout to focus on a specific area or zoom level
fig.update_layout(
geo=dict(
projection_type="mercator",
center={"lat": 44.5, "lon": -63.5},
),
width=900, # Increase the width of the plot
height=700, # Increase the height of the plot
)
fig.show()
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time,
xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
df = track2dataframe(tracks)
plotly_visualize(df, 'lines')
import pandas as pd
from keplergl import KeplerGl
def visualize_with_kepler(data, config=None):
map_1 = KeplerGl(height=600)
map_1.add_data(data=data, name="AIS Data")
map_1.save_to_html(file_name='./figure/kepler_map.html')
return map_1
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time,
xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
df = track2dataframe(tracks)
map_1 = visualize_with_kepler(df)A common issue with AIS data is noise, where multiple vessels may broadcast using the same identifier simultaneously. AISdb incorporates data cleaning techniques to remove noise from vessel track data. For more details:
Denoising with Encoder: The aisdb.denoising_encoder.encode_greatcircledistance() function checks the approximate distance between each of a vessel's consecutive positions. It splits tracks where a vessel could not reasonably have travelled along the most direct path, for example where the implied speed exceeds 50 knots.
Distance and Speed Thresholds: Distance and speed thresholds limit the maximum distance or time between messages that can be considered continuous.
Scoring and Segment Concatenation: A score is computed for each position delta; sequential messages that are close together and received at shorter intervals are given a higher score. This score is calculated by dividing the Haversine distance by the elapsed time. Any delta with a score below the minimum threshold is considered the start of a new segment. New segments are compared to the ends of existing segments with the same vessel identifier; if the score exceeds the minimum, they are concatenated. If multiple segments meet the minimum score, the new segment is concatenated to the existing segment with the highest score.
Processing functions may be executed in sequence as a processing chain or pipeline, so after segmenting the individual voyages, results can be input into the encoder to remove noise and correct for vessels with duplicate identifiers effectively.
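As a simplified illustration of the distance and speed thresholding described above (not the encoder's internal scoring function), consecutive positions implying an unrealistic speed can be flagged like this; the sketch assumes the track's time values are in seconds:
import numpy as np
import aisdb

def flag_unreasonable_jumps(track, max_speed_knots=50):
    # Distance in meters between consecutive positions (see aisdb.gis.delta_meters above)
    meters = aisdb.gis.delta_meters(track)
    # Elapsed time between consecutive messages (assumed to be in seconds here)
    seconds = np.diff(np.asarray(track['time'], dtype=float))
    knots = (meters / seconds) * 1.943844  # convert m/s to knots
    # True where the implied speed is faster than a vessel could reasonably travel
    return knots > max_speed_knots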
import aisdb
from datetime import datetime, timedelta
from aisdb import DBConn, DBQuery, DomainFromPoints
dbpath='YOUR_DATABASE.db' # Define the path to your database
# Set the start and end times for the query
start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
end_time = datetime.strptime("2018-01-02 00:00:00", '%Y-%m-%d %H:%M:%S')
# A circle with a 50 km radius around the location point
domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000])
maxdelta = timedelta(hours=24) # the maximum time interval
distance_threshold = 20000 # the maximum allowed distance (meters) between consecutive AIS messages
speed_threshold = 50 # the maximum allowed vessel speed (in knots) between consecutive AIS messages
minscore = 1e-6 # the minimum score threshold for track segment validation
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time,
callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
# Split the tracks into segments based on the maximum time interval
track_segments = aisdb.split_timedelta(tracks, maxdelta)
# Encode the track segments to clean and validate the track data
tracks_encoded = aisdb.encode_greatcircledistance(track_segments,
distance_threshold=distance_threshold,
speed_threshold=speed_threshold,
minscore=minscore)
tracks_colored = color_tracks(tracks_encoded) # color_tracks() as defined in the web interface examples above
aisdb.web_interface.visualize(
tracks_colored,
domain=domain,
visualearth=True,
open_browser=True,
)After segmentation and encoding, the tracks are shown as:
For comparison, here is a snapshot of the tracks before cleaning:
Track interpolation with AISdb involves generating estimated positions of vessels at specific intervals when actual AIS data points are unavailable. This process is important for filling in gaps in the vessel's trajectory, which can occur due to signal loss, data filtering, or other disruptions.
In this tutorial, we introduce different types of track interpolation implemented in AISdb with usage examples.
First, we define functions to transform and visualize the track data (a generator object), with options to view either the data points or the tracks:
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
def track2dataframe(tracks):
data = []
for track in tracks:
times = track['time']
mmsi = track['mmsi']
lons = track['lon']
lats = track['lat']
# Iterate over the dynamic arrays and create a row for each time point
for i in range(len(times)):
data.append({
'mmsi': mmsi,
'time': times[i],
'longitude': lons[i],
'latitude': lats[i],
})
return pd.DataFrame(data)
def plotly_visualize(data, visual_type='lines'):
if (visual_type=='scatter'):
# Create a scatter plot for the vessel data points using scatter_geo
fig = px.scatter_geo(
data,
lat="latitude",
lon="longitude",
color="mmsi", # Color by vessel identifier
hover_name="mmsi",
hover_data={"time": True},
title="Vessel Data Points"
)
else:
# Create a line plot for the vessel trajectory using line_geo
fig = px.line_geo(
data,
lat="latitude",
lon="longitude",
color="mmsi", # Color by vessel identifier
hover_name="mmsi",
hover_data={"time": True},
title="Vessel Trajectory"
)
# Set the map style and projection
fig.update_geos(
projection_type="azimuthal equal area", # Change this to 'natural earth', 'azimuthal equal area', etc.
showland=True,
landcolor="rgb(243, 243, 243)",
countrycolor="rgb(204, 204, 204)",
)
# Set the layout to focus on a specific area or zoom level
fig.update_layout(
geo=dict(
projection_type="azimuthal equal area",
),
width=1200, # Increase the width of the plot
height=800, # Increase the height of the plot
)
fig.show()We will use an actual track retrieved from the database for the examples in this tutorial and interpolate additional data points based on this track. The visualization will show the original track data points:
import aisdb
import numpy as np
import nest_asyncio
from aisdb import DBConn, DBQuery
from datetime import timedelta, datetime
nest_asyncio.apply()
dbpath='YOUR_DATABASE.db' # Define the path to your database
MMSI = 636017611 # MMSI of the vessel
# Set the start and end times for the query
start_time = datetime.strptime("2018-03-10 00:00:00", '%Y-%m-%d %H:%M:%S')
end_time = datetime.strptime("2018-03-31 00:00:00", '%Y-%m-%d %H:%M:%S')
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
# Visualize the original track data points
df = track2dataframe(tracks)
plotly_visualize(df, 'scatter')Linear interpolation estimates the vessel's position by drawing a straight line between two known points and calculating the positions at intermediate times. It is simple, fast, and straightforward but may not accurately represent complex movements.
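As a quick illustration of the underlying formula (independent of AISdb), positions at intermediate times between two known fixes can be computed with NumPy; the coordinates and timestamps below are made-up values:
import numpy as np

# Two known fixes: time in seconds and position in decimal degrees (illustrative values)
t1, lon1, lat1 = 0, -63.54, 44.61
t2, lon2, lat2 = 600, -63.55, 44.62

# Estimate positions every 100 seconds between the two fixes
t = np.arange(t1, t2 + 1, 100)
lon = np.interp(t, [t1, t2], [lon1, lon2])
lat = np.interp(t, [t1, t2], [lat1, lat2])
print(list(zip(t, lon, lat)))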
This method estimates the position of a vessel at regular time intervals (e.g., every 10 minutes). To perform linear interpolation with an equal time window on the track defined above:
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
tracks__ = aisdb.interp.interp_time(tracks, timedelta(minutes=10))
df = track2dataframe(tracks__)
plotly_visualize(df)This method estimates the position of a vessel at regular spatial intervals (e.g., every 1 km along its path). To perform linear interpolation with equal distance intervals on the track defined above:
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
tracks__ = aisdb.interp.interp_spacing(spacing=500, tracks=tracks)
# Visualizing the tracks
df = track2dataframe(tracks__)
plotly_visualize(df)This method estimates the positions of a vessel along a curved path using the principles of geometry, particularly involving great-circle routes.
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
tracks__ = aisdb.interp.geo_interp_time(tracks, timedelta(minutes=10))
df = track2dataframe(tracks__)
plotly_visualize(df)Given a set of data points, cubic spline interpolation fits a smooth curve through these points. The curve is represented as a series of cubic polynomials between each pair of data points. Each polynomial ensures a smooth curve at the data points (i.e., the first and second derivatives are continuous).
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
tracks__ = aisdb.interp.interp_cubic_spline(tracks, timedelta(minutes=10))
# Visualizing the tracks
df = track2dataframe(tracks__)
plotly_visualize(df)In addition to the standard interpolation methods provided by AISdb, users can implement other interpolation techniques tailored to their specific analytical needs. For instance, B-spline (Basis Spline) interpolation is a mathematical technique that creates a smooth curve through data points. This smoothness is important in trajectory analysis as it avoids sharp, unrealistic turns and maintains a natural flow.
Here is an implementation and example of using B-splines interpolation:
import warnings
import numpy as np
from scipy.interpolate import splrep, splev
def bspline_interpolation(track, key, intervals):
"""
Perform B-Spline interpolation for a specific key on the track data.
Parameters:
- track: Dictionary containing vessel track data (time, lat, lon, etc.).
- key: The dynamic key (e.g., 'lat', 'lon') for which interpolation is performed.
- intervals: The equal time or distance intervals at which interpolation is required.
Returns:
- Interpolated values for the specified key.
"""
# Get time and the key values (e.g., lat/lon) for interpolation
times = track['time']
values = track[key]
# Create the B-Spline representation of the curve
tck = splrep(times, values, s=0) # s=0 means no smoothing, exact fit to data
# Interpolate the values at the given intervals
interpolated_values = splev(intervals, tck)
return interpolated_values
def interp_bspline(tracks, step=1000):
"""
Perform B-Spline interpolation on vessel trajectory data at equal time intervals.
Parameters:
- tracks: List of vessel track dictionaries.
- step: Step for interpolation (can be time or distance-based).
Yields:
- Dictionary containing interpolated lat and lon values for each track.
"""
for track in tracks:
if len(track['time']) <= 1:
warnings.warn('Cannot interpolate track of length 1, skipping...')
continue
# Generate equal time intervals based on the first and last time points
intervals = np.arange(track['time'][0], track['time'][-1], step)
# Perform B-Spline interpolation for lat and lon
interpolated_lat = bspline_interpolation(track, 'lat', intervals)
interpolated_lon = bspline_interpolation(track, 'lon', intervals)
# Yield interpolated track
itr = dict(
mmsi=track['mmsi'],
lat=interpolated_lat,
lon=interpolated_lon,
time=intervals # Including interpolated time intervals for reference
)
yield itrThen, we can apply the function we just implemented to the vessel tracks generator:
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
)
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
tracks__ = interp_bspline(tracks)
# Visualizing the tracks
df = track2dataframe(tracks__)
plotly_visualize(df)The visualization of the interpolated track is shown below:
Automatic Identification System (AIS) data provides a wealth of insights into maritime activities, including vessel movements and traffic patterns. However, the massive volume of AIS data, often consisting of millions or even billions of GPS position points, can be overwhelming. Processing and visualizing this raw data directly can be computationally expensive, slow, and difficult to interpret.
This is where AISdb's decimation comes into play: it helps users efficiently reduce data clutter, making it easier to extract and focus on the most relevant information.
Decimation, in simple terms, means reducing the number of data points. When applied to AIS tracks, it involves selectively removing GPS points from a vessel’s trajectory while preserving its overall shape and key characteristics. Rather than processing every recorded position, decimation algorithms identify and retain the most relevant points, optimizing data efficiency without significant loss of accuracy.
Think of it like simplifying a drawing: instead of using thousands of tiny dots to represent a complex image, you can use fewer, strategically chosen points to capture its essence. Similarly, decimation represents a vessel's path with fewer points while maintaining its core trajectory, making analysis and visualization more efficient.
There are several key benefits for using decimation techniques when working with AIS data:
Improved Performance and Efficiency: Reducing the number of data points can dramatically decrease the computational load, enabling faster analyses, quicker visualizations, and more effective workflow, especially when dealing with large datasets.
Clearer Visualizations: Dense tracks can clutter visualizations and make it difficult to interpret the data. Decimation simplifies the tracks, emphasizing significant movements and patterns for more intuitive analysis.
Noise Reduction: While decimation is not designed as a noise removal technique, it can help smooth out minor inaccuracies and high-frequency fluctuations from raw GPS data. This can be useful for focusing on broader trends and vessel movements.
simplify_linestring_idx()
In AISdb, the TrackGen() method includes a decimate parameter that, when set to True, triggers the simplify_linestring_idx(x, y, precision) function. This function uses the Visvalingam-Whyatt algorithm to simplify vessel tracks while preserving key trajectory details.
The Visvalingam-Whyatt algorithm is an approach to line simplification. It works by removing points that contribute the least to the overall shape of the line. Here’s how it works:
The algorithm measures the importance of a point by calculating the area of the triangle formed by that point and its adjacent points.
Points on relatively straight segments form smaller triangles, meaning they’re less important in defining the shape.
Points at curves and corners form larger triangles, signaling that they’re crucial for maintaining the line’s characteristic form.
The algorithm iteratively removes the points with the smallest triangle areas until the desired level of simplification is achieved. In AISdb, this process is controlled by the decimate parameter in the TrackGen() method.
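As a conceptual illustration of the Visvalingam-Whyatt idea, here is a naive pure-Python sketch (not AISdb's internal implementation):
import numpy as np

def triangle_area(p0, p1, p2):
    # Area of the triangle formed by three (x, y) points (shoelace formula)
    return 0.5 * abs((p1[0] - p0[0]) * (p2[1] - p0[1]) - (p2[0] - p0[0]) * (p1[1] - p0[1]))

def visvalingam(points, n_keep):
    # Repeatedly drop the interior point whose triangle with its neighbours has the smallest area
    pts = list(points)
    while len(pts) > n_keep:
        areas = [triangle_area(pts[i - 1], pts[i], pts[i + 1]) for i in range(1, len(pts) - 1)]
        pts.pop(int(np.argmin(areas)) + 1)  # +1 because areas are indexed over interior points only
    return pts

# Simplify a small zigzag line from 6 points down to 4
line = [(0, 0), (1, 0.1), (2, -0.1), (3, 0.05), (4, 0), (5, 0)]
print(visvalingam(line, 4))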
TrackGen(..., decimate=True) with AISdb Tracks
Below is a conceptual Python example that demonstrates how to apply decimation to AIS tracks:
import aisdb
import numpy as np
# Assuming you have a database connection and domain set up as described
with aisdb.SQLiteDBConn(dbpath='your_ais_database.db') as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn,
start='2023-01-01', end='2023-01-02', # Example time range
xmin=-10, xmax=0, ymin=40, ymax=50, # Example bounding box
callback=aisdb.database.sqlfcn_callbacks.in_validmmsi_bbox,
)
simplified_tracks = aisdb.TrackGen(qry.gen_qry(), decimate=True) # Generate initial tracks
for segment in simplified_tracks:
print(f"Simplified track for MMSI: {segment['mmsi']}, Points: {segment['lon'].size}")
simplify_linestring_idx() with AISdb Tracks
To get finer control over the decimation precision, use the simplify_linestring_idx function in AISdb directly.
import aisdb
import numpy as np
# Assuming you have a database connection and domain set up as described
with aisdb.SQLiteDBConn(dbpath='your_ais_database.db') as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn,
start='2023-01-01', end='2023-01-02', # Example time range
xmin=-10, xmax=0, ymin=40, ymax=50, # Example bounding box
callback=aisdb.database.sqlfcn_callbacks.in_validmmsi_bbox,
)
tracks = aisdb.TrackGen(qry.gen_qry(), decimate=False) # Generate initial tracks
simplified_tracks = []
for track in tracks:
if track['lon'].size > 2: # Ensure track has enough points
# Apply decimation using simplify_linestring_idx
simplified_indices = aisdb.track_gen.simplify_linestring_idx(
track['lon'], track['lat'], precision=0.01 # Example precision
)
# Extract simplified track points
simplified_track = {
'mmsi': track['mmsi'],
'time': track['time'][simplified_indices],
'lon': track['lon'][simplified_indices],
'lat': track['lat'][simplified_indices],
# Carry over other relevant track attributes as needed
}
simplified_tracks.append(simplified_track)
else:
simplified_tracks.append(track) # Keep tracks with few points as is
# Now 'simplified_tracks' contains decimated tracks ready for further analysis
for segment in simplified_tracks:
print(f"Simplified track for MMSI: {segment['mmsi']}, Points: {segment['lon'].size}")
Precision: The precision parameter controls the level of simplification. A smaller value (e.g., 0.001) results in more retained points and higher fidelity, while a larger value (e.g., 0.1) simplifies the track further with fewer points.
x, y: These are NumPy arrays representing the longitude and latitude coordinates of the track points.
TrackGen Integration: In this workflow, tracks are first generated with aisdb.TrackGen (with decimate=False), and simplify_linestring_idx() is then applied to each track individually.
Iterative Refinement: Decimation is often an iterative process. You may need to visualize the decimated tracks, assess the level of simplification, and adjust the precision to balance simplification with data fidelity.
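For example, the effect of different precision values can be compared by counting how many points each setting retains. This sketch reuses the tracks generator from the example above (regenerate it first, since generators can only be consumed once):
# Compare retained point counts across precision values for the first track
for track in tracks:
    if track['lon'].size <= 2:
        continue
    for precision in (0.001, 0.01, 0.1):
        idx = aisdb.track_gen.simplify_linestring_idx(track['lon'], track['lat'], precision=precision)
        print(f"MMSI {track['mmsi']}: precision={precision} -> kept {len(idx)} of {track['lon'].size} points")
    break  # inspect only the first track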
Decimation is a powerful tool for simplifying and decluttering AIS data. By intelligently reducing the data's complexity, AISdb's simplify_linestring_idx() and TrackGen() allow you to process data more efficiently, create clearer visualizations, and gain deeper insights from your maritime data. Experiment with different precision values, and discover how "less" data can lead to "more" meaningful results in your AIS analysis workflows!
Amigo D, Sánchez Pedroche D, García J, Molina JM. Review and classification of trajectory summarisation algorithms: From compression to segmentation. International Journal of Distributed Sensor Networks. 2021;17(10). doi:10.1177/15501477211050729
This section demonstrates integrating AIS data with external bathymetric data to enrich our analysis. In the following example, we identified all vessels within a 500-kilometer radius around the central area of Halifax, Canada, on January 1, 2018.
First, we imported the necessary packages and prepared the bathymetry data. It’s important to note that the downloaded bathymetric data is divided into eight segments, organized by latitude and longitude. In a later step, you will need to select the appropriate bathymetric raster file based on the geographical region covered by your vessel track data.
import os
import aisdb
import nest_asyncio
from datetime import datetime
from aisdb.database.dbconn import SQLiteDBConn
from aisdb import DBConn, DBQuery, DomainFromPoints
nest_asyncio.apply()
# set the path to the data storage directory
bathymetry_data_dir = "./bathymetry_data/"
# check if the directory exists
if not os.path.exists(bathymetry_data_dir):
os.makedirs(bathymetry_data_dir)
# check if the directory is empty
if os.listdir(bathymetry_data_dir) == []:
# download the bathymetry data
bathy = aisdb.webdata.bathymetry.Gebco(data_dir=bathymetry_data_dir)
bathy.fetch_bathymetry_grid()
else:
print("Bathymetry data already exists.")We defined a coloring criterion to classify tracks based on their average depths relative to the bathymetry. Tracks that traverse shallow waters with an average depth of less than 100 meters are colored in yellow. Those spanning depths between 100 and 1,000 meters are represented in orange, indicating a transition to deeper waters. As the depth increases, tracks reaching up to 20 kilometers are marked pink. The deepest tracks, descending beyond 20 kilometers, are distinctly colored in red.
def add_color(tracks):
for track in tracks:
# Calculate the average coastal distance
avg_coast_distance = sum(abs(dist) for dist in track['coast_distance']) / len(track['coast_distance'])
# Determine the color based on the average coastal distance
if avg_coast_distance <= 100:
track['color'] = "yellow"
elif avg_coast_distance <= 1000:
track['color'] = "orange"
elif avg_coast_distance <= 20000:
track['color'] = "pink"
else:
track['color'] = "red"
yield trackNext, we query the AIS data to be integrated with the bathymetric raster file and apply the coloring function to mark the tracks based on their average depths relative to the bathymetry.
dbpath = 'YOUR_DATABASE.db' # define the path to your database
end_time = datetime.strptime("2018-01-02 00:00:00", '%Y-%m-%d %H:%M:%S')
start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[500000])
with SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = DBQuery(
dbconn=dbconn, start=start_time, end=end_time,
xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
)
tracks = aisdb.track_gen.TrackGen(qry.gen_qry(), decimate=False)
# Merge the tracks with the raster data
raster_path = "./bathymetry_data/gebco_2022_n90.0_s0.0_w-90.0_e0.0.tif"
raster = aisdb.webdata.load_raster.RasterFile(raster_path)
tracks_raster = raster.merge_tracks(tracks, new_track_key="coast_distance")
# Add color to the tracks
tracks_colored = add_color(tracks_raster)
if __name__ == '__main__':
aisdb.web_interface.visualize(
tracks_colored,
visualearth=True,
open_browser=True,
)
The integrated results are color-coded and can be visualized as shown below:
Example of using bathymetry data to color-code vessel tracks based on their average depth:
Yellow: Tracks with an average depth of less than 100 meters (shallow waters).
Orange: Tracks with an average depth between 100 and 1,000 meters (transition to deeper waters).
Pink: Tracks with an average depth between 1,000 and 20,000 meters (deeper waters).
Red: Tracks with an average depth greater than 20,000 meters (deepest waters).
This tutorial introduces integrating weather data from GRIB files with AIS data for enhanced vessel tracking analysis. Practical examples are provided below illustrating how to integrate AISdb tracks with the weather data in GRIB files.
To directly work with the jupyter notebook, click here: https://github.com/AISViz/AISdb/blob/master/examples/weather.ipynb
Before diving in, users are expected to have the following set-up: a Copernicus CDS account (free) to access ERA5 data, which can be obtained through the ECMWF-Signup, and AISdb set up either locally or remotely. Refer to the AISDB-Installation documentation for detailed instructions and configuration options.
Once you have a Copernicus CDS account and AISdb installed, you can download weather data in GRIB format directly from the CDS and use AISdb to extract specific variables from those files.
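If the CDS download relies on the standard cdsapi client under the hood (an assumption; check your AISdb version), the client reads its credentials from a ~/.cdsapirc file. The snippet below writes a placeholder configuration; replace the token with the personal access token from your CDS profile page and confirm the expected url/key format in the CDS API documentation.
from pathlib import Path

# Write a placeholder ~/.cdsapirc for the cdsapi client (values are placeholders)
cdsapirc = Path.home() / ".cdsapirc"
if not cdsapirc.exists():
    cdsapirc.write_text(
        "url: https://cds.climate.copernicus.eu/api\n"
        "key: <YOUR-CDS-API-TOKEN>\n"
    )
    print(f"Wrote placeholder CDS credentials to {cdsapirc}")
else:
    print(f"Found existing CDS credentials at {cdsapirc}")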
AISdb supports both zipped (.zip) and uncompressed GRIB (.grib) files. These files should be named using the yyyy-mm format (e.g., 2023-03.grib or 2023-03.zip) and placed in a folder such as /home/CanadaV2.
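As a quick sanity check of this naming convention, a small helper (hypothetical, not part of AISdb) can flag any files in the folder that do not follow the yyyy-mm pattern:
import os
import re

def check_grib_names(folder="/home/CanadaV2"):
    # Flag files that do not match the yyyy-mm.grib / yyyy-mm.zip naming scheme
    pattern = re.compile(r"^\d{4}-\d{2}\.(grib|zip)$")
    for name in sorted(os.listdir(folder)):
        status = "ok" if pattern.match(name) else "unexpected name"
        print(f"{name}: {status}")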
With the WeatherDataStore class in AISdb, you can specify:
The desired weather variables (e.g., '10u', '10v', 'tp'),
A date range (e.g., August 1 to August 30, 2023),
And the directory where the GRIB files are stored.
To automatically download GRIB files from the Copernicus Climate Data Store (CDS), AISdb provides a convenient option using the WeatherDataStore class.
Simply set the parameter download_from_cds=True, and specify the required weather variable short names, date range, target area, and output path.
To extract weather data for specific latitude, longitude, and timestamp values, call the yield_tracks_with_weather() method. This returns a dictionary containing the requested weather variables for each location-time pair.
from aisdb.weather.data_store import WeatherDataStore # for weather
# ...some code before...
with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
qry = aisdb.DBQuery(
dbconn=dbconn,
start=start_time,
end=end_time,
callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
)
rowgen = qry.gen_qry()
# Convert queried rows to vessel trajectories
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
# Mention the short_names for the required weather data from the grib file
weather_data_store = WeatherDataStore(short_names=['10u', '10v', 'tp'], start=start_time, end=end_time, weather_data_path=".", download_from_cds=False, area=[-70, 45, -58, 53])
tracks = weather_data_store.yield_tracks_with_weather(tracks)
for t in tracks:
print(f"'u-component' 10m wind for:\nlat: {t['lat'][0]} \nlon: {t['lon'][0]} \ntime: {t['time'][0]} \nis {t['weather_data']['u10'][0]} m/s")
break
weather_data_store.close()
Output:
'u-component' 10m wind for:
lat: 50.003334045410156
lon: -66.76000213623047
time: 1690858823
is 1.9680767059326172 m/s
What are short_names? In ECMWF (European Centre for Medium-Range Weather Forecasts) terminology, a "short name" is a concise, often abbreviated identifier that represents a specific meteorological parameter or variable within data files such as GRIB files. For example, "t2m" is the short name for "2-meter temperature".
For a list of short names for different weather components, refer to: https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#heading-Parameterlistings
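For convenience, a few short names that appear in this tutorial (plus one common addition, 't2m') are listed below; the mapping is illustrative, not exhaustive:
# Illustrative subset of ERA5 short names and the variables they refer to
ERA5_SHORT_NAMES = {
    "10u": "10 metre U wind component",
    "10v": "10 metre V wind component",
    "tp": "total precipitation",
    "t2m": "2 metre temperature",
}
for short_name, description in ERA5_SHORT_NAMES.items():
    print(f"{short_name}: {description}")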
Let's work through an example where we retrieve AIS tracks from AISdb and call WeatherDataStore to add weather data to the tracks.
import aisdb
import nest_asyncio
from aisdb import DBQuery
from aisdb.database.dbconn import PostgresDBConn
from datetime import datetime
from PIL import ImageFile
from aisdb.weather.data_store import WeatherDataStore # for weather
nest_asyncio.apply()
ImageFile.LOAD_TRUNCATED_IMAGES = True
# >>> PostgreSQL Information <<<
db_user='' # DB User
db_dbname='aisviz' # DB Schema
db_password='' # DB Password
db_hostaddr='127.0.0.1' # DB Host address
dbconn = PostgresDBConn(
port=5555, # PostgreSQL port
user=db_user, # PostgreSQL username
dbname=db_dbname, # PostgreSQL database
host=db_hostaddr, # PostgreSQL address
password=db_password, # PostgreSQL password
)
Specify the region and duration for which you wish the tracks to be generated. TrackGen returns a generator containing all the dynamic and static column values of the AIS data.
xmin, ymin, xmax, ymax = -70, 45, -58, 53
gulf_bbox = [xmin, xmax, ymin, ymax]
start_time = datetime(2023, 8, 1)
end_time = datetime(2023, 8, 30)
qry = DBQuery(
dbconn=dbconn,
start=start_time, end=end_time,
xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax,
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi
)
ais_tracks = []
rowgen = qry.gen_qry()
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=True)
Next, to merge the tracks with the weather data, we use the WeatherDataStore class from aisdb.weather.data_store. By calling WeatherDataStore with the required weather components (using their short names) and providing the path to your GRIB files, it opens the files and returns an object through which you can perform further operations.
weather_data_store = WeatherDataStore(
short_names=['10u', '10v', 'tp'], # U & V wind components, total precipitation
start=start_time, # e.g., datetime(2023, 8, 1)
end=end_time, # e.g., datetime(2023, 8, 30)
weather_data_path=".", # Local folder to store downloaded GRIBs
download_from_cds=True, # Enable download from CDS
area=[-70, 45, -58, 53] # [west, south, east, north] in degrees
)
Here, 10u and 10v are the 10 metre U and V wind components, respectively.
By using the method weather_data_store.yield_tracks_with_weather(tracks), the tracks are concatenated with the weather data.
Example usage:
tracks_with_weather = weather_data_store.yield_tracks_with_weather(tracks)
for track in tracks_with_weather:
print(f"'u-component' 10m wind for:\nlat: {track['lat'][0]} \nlon: {track['lon'][0]} \ntime: {track['time'][0]} \nis {track['weather_data']['u10'][0]} m/s")
break
weather_data_store.close() # gracefully closes the opened GRIB file
By integrating weather data with AIS data, we can study the patterns of ship movement in relation to weather conditions. This integration allows us to analyze how factors such as wind, sea surface temperature, and atmospheric pressure influence vessel trajectories. By merging dynamic AIS data with detailed climate variables, we gain deeper insights into the environmental factors affecting shipping routes and can optimize vessel navigation for better efficiency and safety.
The Automatic Identification System (AIS) is a standardized and unencrypted self-reporting maritime surveillance system.
The protocol operates by transmitting one or more of 27 message types from an AIS transponder onboard a vessel at fixed time intervals. These intervals depend on the vessel’s status—stationary vessels (anchored or moored) transmit every 3 minutes, while fast-moving vessels transmit every 2 seconds.
These VHF radio messages are sent from the vessel’s transponder and received by either satellite or ground-based stations, enabling more detailed monitoring and analysis of maritime traffic.
Dynamic messages convey the vessel's real-time status, which can vary between transmissions. These include data such as Speed Over Ground (SOG), Course Over Ground (COG), Rate of Turn (ROT), and the vessel’s current position (latitude and longitude).
Static messages, on the other hand, provide information that remains constant over time. This includes details like the Maritime Mobile Service Identity (MMSI), International Maritime Organization (IMO) number, vessel name, call sign, type, dimensions, and intended destination.
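As a simplified illustration of the two categories (the field names below are for explanation only and are not the exact AISdb column names):
# Simplified, illustrative grouping of AIS fields by message category
dynamic_fields = ["mmsi", "time", "latitude", "longitude", "sog", "cog", "rot"]
static_fields = ["mmsi", "imo", "vessel_name", "call_sign", "ship_type",
                 "dimensions", "destination"]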
Signals from vessels can be lost: terrestrial base stations are limited by their physical range, while satellite AIS receivers are limited by their orbital position and global coverage.
To learn more about AIS, refer to: https://www.marinelink.com/news/definitive-guide-ais418266
Brousseau, M. (2022). A comprehensive analysis and novel methods for on-purpose AIS switch-off detection [Master’s thesis, Dalhousie University]. DalSpace. http://hdl.handle.net/10222/81160
Kazim, T. (2016, November 14). A definitive guide to AIS. MarineLink. Retrieved May 14, 2025, from https://www.marinelink.com/news/definitive-guide-ais418266
Trajectory Forecasting with Gated Recurrent Unit AutoEncoders
By the end of this tutorial, you will understand the benefits of using teacher forcing to improve model accuracy, as well as other tweaks to enhance forecasting capabilities. We'll use AutoEncoders, neural networks that learn compressed data representations, to achieve this.
We will guide you through preparing AIS data for training an AutoEncoder, setting up layers, compiling the model, and defining the training process with teacher forcing.
Given the complexity of this task, we will revisit it to explore the benefits of teacher forcing, a technique that can improve sequence-to-sequence learning in neural networks.
This tutorial focuses on Trajectory Forecasting, which predicts an object's future path based on past positions. We will work with AIS messages, a type of temporal data that provides information about vessels' location, speed, and heading over time.
Automatic Identification System (AIS) messages broadcast essential ship information such as position, speed, and course. The temporal nature of these messages is pivotal for our tutorial, where we'll train an auto-encoder neural network for trajectory forecasting. This task involves predicting a ship's future path based on its past AIS messages, making it ideal for auto-encoders, which are optimized for learning patterns in sequential data.
def qry_database(dbname, start_time, stop_time):
d_threshold = 200000 # max distance (in meters) between two messages before assuming separate tracks
s_threshold = 50 # max speed (in knots) between two AIS messages before splitting tracks
t_threshold = timedelta(hours=24) # max time (in hours) between messages of a track
try:
with aisdb.DBConn() as dbconn:
tracks = aisdb.TrackGen(
aisdb.DBQuery(
dbconn=dbconn, dbpath=os.path.join(ROOT, PATH, dbname),
callback=aisdb.database.sql_query_strings.in_timerange,
start=start_time, end=stop_time).gen_qry(),
decimate=False) # trajectory compression
tracks = aisdb.split_timedelta(tracks, t_threshold) # split trajectories by time without AIS message transmission
tracks = aisdb.encode_greatcircledistance(tracks, distance_threshold=d_threshold, speed_threshold=s_threshold)
tracks = aisdb.interp_time(tracks, step=timedelta(minutes=5)) # interpolate every n-minutes
# tracks = vessel_info(tracks, dbconn=dbconn) # scrapes vessel metadata
return list(tracks) # list of segmented pre-processed tracks
except SyntaxError as e: return [] # no results for query
For querying the entire database at once, use the following code:
def get_tracks(dbname, start_ddmmyyyy, stop_ddmmyyyy):
stop_time = datetime.strptime(stop_ddmmyyyy, "%d%m%Y")
start_time = datetime.strptime(start_ddmmyyyy, "%d%m%Y")
# returns a list with all tracks from AISdb
return qry_database(dbname, start_time, stop_time)
For querying the database in batches of hours, use the following code:
def batch_tracks(dbname, start_ddmmyyyy, stop_ddmmyyyy, hours2batch):
stop_time = datetime.strptime(stop_ddmmyyyy, "%d%m%Y")
start_time = datetime.strptime(start_ddmmyyyy, "%d%m%Y")
# yields a list of results every delta_time iterations
delta_time = timedelta(hours=hours2batch)
anchor_time, next_time = start_time, start_time + delta_time
while next_time < stop_time:
yield qry_database(dbname, anchor_time, next_time)
anchor_time = next_time
next_time += delta_time
# yields a list of final results (if any)
yield qry_database(dbname, anchor_time, stop_time)
Several functions were defined using AISdb, an AIS framework developed by MERIDIAN at Dalhousie University, to efficiently extract AIS messages from SQLite databases. AISdb is designed for effective data storage, retrieval, and preparation for AIS-related tasks. It provides comprehensive tools for interacting with AIS data, including APIs for data reading and writing, parsing AIS messages, and performing various data transformations.
Our next step is to create a coverage map of Atlantic Canada to visualize our dataset. We will include 100 km-radius circles on the map to show the areas of the ocean where AIS messages from vessels can be received. Although overlapping circles may contain duplicate data from the same MMSI, we have already eliminated those from our dataset. However, messages might still appear incorrectly in inland areas.
# Create the map with specific latitude and longitude limits and a Mercator projection
m = Basemap(llcrnrlat=42, urcrnrlat=52, llcrnrlon=-70, urcrnrlon=-50, projection="merc", resolution="h")
# Draw state, country, coastline borders, and counties
m.drawstates(0.5)
m.drawcountries(0.5)
m.drawcoastlines(0.5)
m.drawcounties(color="gray", linewidth=0.5)
# Fill continents and oceans
m.fillcontinents(color="tan", lake_color="#91A3B0")
m.drawmapboundary(fill_color="#91A3B0")
coordinates = [
(51.26912, -57.53759), (48.92733, -58.87786),
(47.49307, -59.41325), (42.54760, -62.17624),
(43.21702, -60.49943), (44.14955, -60.59600),
(45.42599, -59.76398), (46.99134, -60.02403)]
# Draw 100km-radius circles
for lat, lon in coordinates:
radius_in_degrees = 100 / (111.32 * np.cos(np.deg2rad(lat)))
m.tissot(lon, lat, radius_in_degrees, 100, facecolor="r", edgecolor="k", alpha=0.5)
# Add text annotation with an arrow pointing to the circle
plt.annotate("AIS Coverage", xy=m(lon, lat), xytext=(40, -40),
textcoords="offset points", ha="left", va="bottom", fontweight="bold",
arrowprops=dict(arrowstyle="->", color="k", alpha=0.7, lw=2.5))
# Add labels
ocean_labels = {
"Atlantic Ocean": [(-59, 44), 16],
"Gulf of\nMaine": [(-67, 44.5), 12],
"Gulf of St. Lawrence": [(-64.5, 48.5), 11],
}
for label, (coords, fontsize) in ocean_labels.items():
plt.annotate(label, xy=m(*coords), xytext=(6, 6), textcoords="offset points",
fontsize=fontsize, color="#DBE2E9", fontweight="bold")
# Add a scale in kilometers
m.drawmapscale(-67.5, 42.7, -67.5, 42.7, 500, barstyle="fancy", fontsize=8, units="km", labelstyle="simple")
# Set the map title
_ = plt.title("100km-AIS radius-coverage on Atlantic Canada", fontweight="bold")
# The circle diameter is 200km, and it does not match the km scale (approximation)
Loading a shapefile to help us define whether a vessel is on land or in water during the trajectory:
land_polygons = gpd.read_file(os.path.join(ROOT, SHAPES, "ne_50m_land.shp"))
Check if a given coordinate (latitude, longitude) is on land:
def is_on_land(lat, lon, land_polygons):
return land_polygons.contains(Point(lon, lat)).any()
Check if any coordinate of a track is on land:
def is_track_on_land(track, land_polygons):
for lat, lon in zip(track["lat"], track["lon"]):
if is_on_land(lat, lon, land_polygons):
return True
return False
Filter out tracks with any point on land for a given MMSI:
def process_mmsi(item, polygons):
mmsi, tracks = item
filtered_tracks = [t for t in tracks if not is_track_on_land(t, polygons)]
return mmsi, filtered_tracks, len(tracks)
Use a ThreadPoolExecutor to parallelize the processing of MMSIs:
def process_voyages(voyages, land_polygons):
# Tracking progress with TQDM
def process_mmsi_callback(result, progress_bar):
mmsi, filtered_tracks, _ = result
voyages[mmsi] = filtered_tracks
progress_bar.update(1)
# Initialize the progress bar with the total number of MMSIs
progress_bar = tqdm(total=len(voyages), desc="MMSIs processed")
with ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
# Submit all MMSIs for processing
futures = {executor.submit(process_mmsi, item, land_polygons): item for item in voyages.items()}
# Retrieve the results as they become available and update the Voyages dictionary
for future in as_completed(futures):
result = future.result()
process_mmsi_callback(result, progress_bar)
# Close the progress bar after processing is complete
progress_bar.close()
return voyages
file_name = "curated-ais.pkl"
full_path = os.path.join(ROOT, ESRF, file_name)
if not os.path.exists(full_path):
voyages = process_voyages(voyages, land_polygons)
pkl.dump(voyages, open(full_path, "wb"))
else: voyages = pkl.load(open(full_path, "rb"))
Count the number of segments per MMSI after removing duplicates and inaccurate track segments:
voyages_counts = {k: len(voyages[k]) for k in voyages.keys()}
In this analysis, we observe that most MMSIs in the dataset exhibit between 1 and 49 segments during the search period within AISdb. However, a minor fraction of vessels have significantly more segments, with some reaching up to 176. Efficient processing involves categorizing the data by MMSI instead of merely considering its volume. This method allows us to better evaluate the model's ability to discern various movement behaviors from both the same vessel and different ones.
def plot_voyage_segments_distribution(voyages_counts, bar_color="#ba1644"):
data = pd.DataFrame({"Segments": list(voyages_counts.values())})
return alt.Chart(data).mark_bar(color=bar_color).encode(
alt.X("Segments:Q", bin=alt.Bin(maxbins=90), title="Segments"),
alt.Y("count(Segments):Q", title="Count", scale=alt.Scale(type="log")))\
.properties(title="Distribution of Voyage Segments", width=600, height=400)\
.configure_axisX(titleFontSize=16).configure_axisY(titleFontSize=16)\
.configure_title(fontSize=18).configure_view(strokeOpacity=0)
alt.data_transformers.enable("default", max_rows=None)
plot_voyage_segments_distribution(voyages_counts).display()
To prevent our model from favoring shorter trajectories, we need a balanced mix of short-term and long-term voyages in the training and test sets. We'll categorize trajectories with 30 or more segments as long-term and those with fewer segments as short-term. Implement an 80-20 split strategy to ensure an equitable distribution of both types in the datasets.
long_term_voyages, short_term_voyages = [], []
# Separating voyages
for k in voyages_counts:
if voyages_counts[k] < 30:
short_term_voyages.append(k)
else: long_term_voyages.append(k)
# Shuffling for random distribution
random.shuffle(short_term_voyages)
random.shuffle(long_term_voyages)
Splitting the data respecting the voyage length distribution:
train_voyage, test_voyage = {}, {}
# Iterate over short-term voyages:
for i, k in enumerate(short_term_voyages):
if i < int(0.8 * len(short_term_voyages)):
train_voyage[k] = voyages[k]
else: test_voyage[k] = voyages[k]
# Iterate over long-term voyages:
for i, k in enumerate(long_term_voyages):
if i < int(0.8 * len(long_term_voyages)):
train_voyage[k] = voyages[k]
else: test_voyage[k] = voyages[k]
Visualizing the distribution of the dataset:
def plot_voyage_length_distribution(data, title, bar_color, min_time=144, force_print=True):
total_time = []
for key in data.keys():
for track in data[key]:
if len(track["time"]) > min_time or force_print:
total_time.append(len(track["time"]))
plot_data = pd.DataFrame({'Length': total_time})
chart = alt.Chart(plot_data).mark_bar(color=bar_color).encode(
alt.Y("count(Length):Q", title="Count", scale=alt.Scale(type="symlog")),
alt.X("Length:Q", bin=alt.Bin(maxbins=90), title="Length")
).properties(title=title, width=600, height=400)\
.configure_axisX(titleFontSize=16).configure_axisY(titleFontSize=16)\
.configure_title(fontSize=18).configure_view(strokeOpacity=0)
print("\n\n")
return chart
display(plot_voyage_length_distribution(train_voyage, "TRAINING: Distribution of Voyage Length", "#287561"))
display(plot_voyage_length_distribution(test_voyage, "TEST: Distribution of Voyage Length", "#3e57ab"))
Understanding input and output timesteps and variables is crucial in trajectory forecasting tasks. Trajectory data comprises spatial coordinates and related features that depict an object's movement over time. The aim is to predict future positions of the object based on its historical data and associated features.
INPUT_TIMESTEPS: This parameter determines the consecutive observations used to predict future trajectories. Its selection impacts the model's ability to capture temporal dependencies and patterns. Too few time steps may prevent the model from capturing all movement dynamics, resulting in inaccurate predictions. Conversely, too many time steps can add noise and complexity, increasing the risk of overfitting.
INPUT_VARIABLES: Features describe each timestep in the input sequence for trajectory forecasting. These variables can include spatial coordinates, velocities, accelerations, object types, and relevant features that aid in predicting system dynamics. Choosing the right input variables is crucial; irrelevant or redundant ones may confuse the model while missing important variables can result in poor predictions.
OUTPUT_TIMESTEPS: This parameter sets the number of future time steps the model should predict, known as the prediction horizon. Choosing the right horizon size is critical. Predicting too few timesteps may not serve the application's needs while predicting too many can increase uncertainty and degrade performance. Select a value based on your application's specific requirements and data quality.
OUTPUT_VARIABLES: In trajectory forecasting, output variables include predicted spatial coordinates and sometimes other relevant features. Reducing the number of output variables can simplify prediction tasks and decrease model complexity. However, this approach might also lead to a less effective model.
Understanding the roles of input and output timesteps and variables is key to developing accurate trajectory forecasting models. By carefully selecting these elements, we can create models that effectively capture object movement dynamics, resulting in more accurate and meaningful predictions across various applications.
For this tutorial, we'll input 4 hours of data into the model to forecast the next 8 hours of vessel movement. Consequently, we'll filter out all voyages with less than 12 hours of AIS messages. By interpolating the messages every 5 minutes, we require a minimum of 144 sequential messages (12 hours at 12 messages/hour).
With data provided by AISdb, we have AIS information, including Longitude, Latitude, Course Over Ground (COG), and Speed Over Ground (SOG), representing a ship's position and movement. Longitude and Latitude specify the ship's location, while COG and SOG indicate its heading and speed. By using all features for training the neural network, our output will be the Longitude and Latitude pair. This methodology allows the model to predict the ship's future positions based on historical data.
INPUT_TIMESTEPS = 48 # 4 hours * 12 AIS messages/h
INPUT_VARIABLES = 4 # Longitude, Latitude, COG, and SOG
OUTPUT_TIMESTEPS = 96 # 8 hours * 12 AIS messages/h
OUTPUT_VARIABLES = 2 # Longitude and Latitude
NUM_WORKERS = multiprocessing.cpu_count()
In this tutorial, we'll include AIS data deltas as features, which were excluded in the previous tutorial. Incorporating deltas can help the model capture temporal changes and patterns, enhancing its effectiveness in sequence-to-sequence modeling. Deltas provide information on the rate of change in features, improving the model's accuracy, especially in predicting outcomes that depend on temporal dynamics.
INPUT_VARIABLES *= 2 # Double the features with deltas
def filter_and_transform_voyages(voyages):
filtered_voyages = {}
for k, v in voyages.items():
voyages_track = []
for voyage in v:
if len(voyage["time"]) > (INPUT_TIMESTEPS + OUTPUT_TIMESTEPS):
mtx = np.vstack([voyage["lon"], voyage["lat"],
voyage["cog"], voyage["sog"]]).T
# Compute deltas
deltas = np.diff(mtx, axis=0)
# Add zeros at the first row for deltas
deltas = np.vstack([np.zeros(deltas.shape[1]), deltas])
# Concatenate the original matrix with the deltas matrix
mtx = np.hstack([mtx, deltas])
voyages_track.append(mtx)
if len(voyages_track) > 0:
filtered_voyages[k] = voyages_track
return filtered_voyages
# Checking how the data behaves for the previously set hyperparameters
display(plot_voyage_length_distribution(train_voyage, "TRAINING: Distribution of Voyage Length", "#287561",
min_time=INPUT_TIMESTEPS + OUTPUT_TIMESTEPS, force_print=False))
display(plot_voyage_length_distribution(test_voyage, "TEST: Distribution of Voyage Length", "#3e57ab",
min_time=INPUT_TIMESTEPS + OUTPUT_TIMESTEPS, force_print=False))
# Filter and transform train and test voyages and prepare for training
train_voyage = filter_and_transform_voyages(train_voyage)
test_voyage = filter_and_transform_voyages(test_voyage)
def print_voyage_statistics(header, voyage_dict):
total_time = 0
for mmsi, trajectories in voyage_dict.items():
for trajectory in trajectories:
total_time += trajectory.shape[0]
print(f"{header}")
print(f"Hours of sequential data: {total_time // 12}.")
print(f"Number of unique MMSIs: {len(voyage_dict)}.", end=" \n\n")
return total_time
time_test = print_voyage_statistics("[TEST DATA]", test_voyage)
time_train = print_voyage_statistics("[TRAINING DATA]", train_voyage)
# We remained with a distribution of data that still resembles the 80-20 ratio
print(f"Training hourly-rate: {(time_train * 100) / (time_train + time_test)}%")
print(f"Test hourly-rate: {(time_test * 100) / (time_train + time_test)}%")To improve our model, we'll prioritize training samples based on trajectory straightness. We'll compute the geographical distance between a segment's start and end points using the Haversine formula. Comparing this to the total distance of all consecutive points will give a straightness metric. Our model will focus on complex trajectories with multiple direction changes, leading to better generalization and more accurate predictions.
def haversine_distance(lon_1, lat_1, lon_2, lat_2):
lon_1, lat_1, lon_2, lat_2 = map(np.radians, [lon_1, lat_1, lon_2, lat_2]) # convert latitude and longitude to radians
a = np.sin((lat_2 - lat_1) / 2) ** 2 + np.cos(lat_1) * np.cos(lat_2) * np.sin((lon_2 - lon_1) / 2) ** 2
return (2 * np.arcsin(np.sqrt(a))) * 6371000 # R: 6,371,000 meters
Trajectory straightness calculation using the Haversine:
def trajectory_straightness(x):
start_point, end_point = x[0, :2], x[-1, :2]
x_coordinates, y_coordinates = x[:-1, 0], x[:-1, 1]
x_coordinates_next, y_coordinates_next = x[1:, 0], x[1:, 1]
consecutive_distances = np.array(haversine_distance(x_coordinates, y_coordinates, x_coordinates_next, y_coordinates_next))
straight_line_distance = np.array(haversine_distance(start_point[0], start_point[1], end_point[0], end_point[1]))
result = straight_line_distance / np.sum(consecutive_distances)
return result if not np.isnan(result) else 1
To predict 96 data points (output) using the preceding 48 data points (input) in a trajectory time series, we create a sliding window. First, we select the initial 48 data points as the input sequence and the subsequent 96 as the output sequence. We then slide the window forward by one step and repeat the process. This continues until the end of the sequence, helping our model capture temporal dependencies and patterns in the data.
Our training strategy uses the sliding window technique, requiring unique weights for each sample. Sliding Windows (SW) transforms time series data into an appropriate format for machine learning. They generate overlapping windows with a fixed number of consecutive points by sliding the window one step at a time through the series.
def process_voyage(voyage, mmsi, max_size, overlap_size=1):
straightness_ratios, mmsis, x, y = [], [], [], []
for j in range(0, voyage.shape[0] - max_size, 1):
x_sample = voyage[(0 + j):(INPUT_TIMESTEPS + j)]
y_sample = voyage[(INPUT_TIMESTEPS + j - overlap_size):(max_size + j), 0:OUTPUT_VARIABLES]
straightness = trajectory_straightness(x_sample)
straightness_ratios.append(straightness)
x.append(x_sample.T)
y.append(y_sample.T)
mmsis.append(mmsi)
return straightness_ratios, mmsis, x, y
def process_data(voyages):
max_size = INPUT_TIMESTEPS + OUTPUT_TIMESTEPS
# Callback function to update tqdm progress bar
def process_voyage_callback(result, pbar):
pbar.update(1)
return result
with Pool(NUM_WORKERS) as pool, tqdm(total=sum(len(v) for v in voyages.values()), desc="Voyages") as pbar:
results = []
# Submit tasks to the pool and store the results
for mmsi in voyages:
for voyage in voyages[mmsi]:
callback = partial(process_voyage_callback, pbar=pbar)
results.append(pool.apply_async(process_voyage, (voyage, mmsi, max_size), callback=callback))
pool.close()
pool.join()
# Gather the results
straightness_ratios, mmsis, x, y = [], [], [], []
for result in results:
s_ratios, s_mmsis, s_x, s_y = result.get()
straightness_ratios.extend(s_ratios)
mmsis.extend(s_mmsis)
x.extend(s_x)
y.extend(s_y)
# Process the results
x, y = np.stack(x), np.stack(y)
x, y = np.transpose(x, (0, 2, 1)), np.transpose(y, (0, 2, 1))
straightness_ratios = np.array(straightness_ratios)
min_straightness, max_straightness = np.min(straightness_ratios), np.max(straightness_ratios)
scaled_straightness_ratios = (straightness_ratios - min_straightness) / (max_straightness - min_straightness)
scaled_straightness_ratios = 1. - scaled_straightness_ratios
print(f"Final number of samples = {len(x)}", end="\n\n")
return mmsis, x, y, scaled_straightness_ratios
mmsi_train, x_train, y_train, straightness_ratios = process_data(train_voyage)
mmsi_test, x_test, y_test, _ = process_data(test_voyage)
In this project, the input data includes four features: Longitude, Latitude, COG (Course over Ground), and SOG (Speed over Ground), while the output data includes only Longitude and Latitude. To enhance the model's learning, we need to normalize the data through three main steps.
First, normalize Longitude, Latitude, COG, and SOG to the [0, 1] range using domain-specific parameters. This ensures the model performs well in Atlantic Canada waters by restricting the geographical scope of the AIS data and maintaining a similar scale for all features.
Second, the input and output data are standardized by subtracting the mean and dividing by the standard deviation. This centers the data around zero and scales it by its variance, preventing vanishing gradients during training.
Finally, another zero-one normalization is applied to scale the data to the [0, 1] range, aligning it with the expected range for many neural network activation functions.
def normalize_dataset(x_train, x_test, y_train,
lat_min=42, lat_max=52, lon_min=-70, lon_max=-50, max_sog=50):
def normalize(arr, min_val, max_val):
return (arr - min_val) / (max_val - min_val)
# Initial normalization
x_train[:, :, :2] = normalize(x_train[:, :, :2], np.array([lon_min, lat_min]), np.array([lon_max, lat_max]))
y_train[:, :, :2] = normalize(y_train[:, :, :2], np.array([lon_min, lat_min]), np.array([lon_max, lat_max]))
x_test[:, :, :2] = normalize(x_test[:, :, :2], np.array([lon_min, lat_min]), np.array([lon_max, lat_max]))
x_train[:, :, 2:4] = x_train[:, :, 2:4] / np.array([360, max_sog])
x_test[:, :, 2:4] = x_test[:, :, 2:4] / np.array([360, max_sog])
# Standardize X and Y
x_mean, x_std = np.mean(x_train, axis=(0, 1)), np.std(x_train, axis=(0, 1))
y_mean, y_std = np.mean(y_train, axis=(0, 1)), np.std(y_train, axis=(0, 1))
x_train = (x_train - x_mean) / x_std
y_train = (y_train - y_mean) / y_std
x_test = (x_test - x_mean) / x_std
# Final zero-one normalization
x_min, x_max = np.min(x_train, axis=(0, 1)), np.max(x_train, axis=(0, 1))
y_min, y_max = np.min(y_train, axis=(0, 1)), np.max(y_train, axis=(0, 1))
x_train = (x_train - x_min) / (x_max - x_min)
y_train = (y_train - y_min) / (y_max - y_min)
x_test = (x_test - x_min) / (x_max - x_min)
return x_train, x_test, y_train, y_mean, y_std, y_min, y_max, x_mean, x_std, x_min, x_max
x_train, x_test, y_train, y_mean, y_std, y_min, y_max, x_mean, x_std, x_min, x_max = normalize_dataset(x_train, x_test, y_train)
Denormalizing Y output to the original scale of the data:
def denormalize_y(y_data, y_mean, y_std, y_min, y_max,
lat_min=42, lat_max=52, lon_min=-70, lon_max=-50):
y_data = y_data * (y_max - y_min) + y_min # reverse zero-one normalization
y_data = y_data * y_std + y_mean # reverse standardization
# Reverse initial normalization for longitude and latitude
y_data[:, :, 0] = y_data[:, :, 0] * (lon_max - lon_min) + lon_min
y_data[:, :, 1] = y_data[:, :, 1] * (lat_max - lat_min) + lat_min
return y_data
Denormalizing X output to the original scale of the data:
def denormalize_x(x_data, x_mean, x_std, x_min, x_max,
lat_min=42, lat_max=52, lon_min=-70, lon_max=-50):
x_data = x_data * (x_max - x_min) + x_min # reverse zero-one normalization
x_data = x_data * x_std + x_mean # reverse standardization
# Reverse initial normalization for longitude and latitude
x_data[:, :, 0] = x_data[:, :, 0] * (lon_max - lon_min) + lon_min
x_data[:, :, 1] = x_data[:, :, 1] * (lat_max - lat_min) + lat_min
return x_data
We have successfully prepared the data for our machine-learning task. With the data ready, it's time for the modeling phase. Next, we will create, train, and evaluate a machine-learning model to forecast vessel trajectories using the processed dataset. Let's explore how our model performs in Atlantic Canada!
tf.keras.backend.clear_session() # Clear the Keras session to prevent potential conflicts
_ = wandb.login(force=True) # Log in to Weights & Biases
A GRU Autoencoder is a neural network that compresses and reconstructs sequential data utilizing Gated Recurrent Units. GRUs are highly effective at handling time-series data, which are sequential data points captured over time, as they can model intricate temporal dependencies and patterns. To perform time-series forecasting, a GRU Autoencoder can be trained on a historical time-series dataset to discern patterns and trends, compressing an input sequence into a lower-dimensional representation that can be decoded to generate a forecast of the upcoming data points. With this in mind, we will begin by constructing a model architecture composed of two GRU layers with 64 units each, an encoder over the (48, 8) input sequence and a decoder over the (95, 2) output sequence, followed by a time-distributed dense layer with 2 units.
class ProbabilisticTeacherForcing(Layer):
def __init__(self, **kwargs):
super(ProbabilisticTeacherForcing, self).__init__(**kwargs)
def call(self, inputs):
decoder_gt_input, decoder_output, mixing_prob = inputs
mixing_prob = tf.expand_dims(mixing_prob, axis=-1) # Add an extra dimension for broadcasting
mixing_prob = tf.broadcast_to(mixing_prob, tf.shape(decoder_gt_input)) # Broadcast to match the shape
return tf.where(tf.random.uniform(tf.shape(decoder_gt_input)) < mixing_prob, decoder_gt_input, decoder_output)
def build_model(rnn_unit="GRU", hidden_size=64):
encoder_input = Input(shape=(INPUT_TIMESTEPS, INPUT_VARIABLES), name="Encoder_Input")
decoder_gt_input = Input(shape=((OUTPUT_TIMESTEPS - 1), OUTPUT_VARIABLES), name="Decoder-GT-Input")
mixing_prob_input = Input(shape=(1,), name="Mixing_Probability")
# Encoder
encoder_gru = eval(rnn_unit)(hidden_size, activation="relu", name="Encoder")(encoder_input)
repeat_vector = RepeatVector((OUTPUT_TIMESTEPS - 1), name="Repeater")(encoder_gru)
# Inference Decoder
decoder_gru = eval(rnn_unit)(hidden_size, activation="relu", return_sequences=True, name="Decoder")
decoder_output = decoder_gru(repeat_vector, initial_state=encoder_gru)
# Adjust decoder_output shape
dense_output_adjust = TimeDistributed(Dense(OUTPUT_VARIABLES), name="Output_Adjust")
adjusted_decoder_output = dense_output_adjust(decoder_output)
# Training Decoder
decoder_gru_tf = eval(rnn_unit)(hidden_size, activation="relu", return_sequences=True, name="Decoder-TF")
probabilistic_tf_layer = ProbabilisticTeacherForcing(name="Probabilistic_Teacher_Forcing")
mixed_input = probabilistic_tf_layer([decoder_gt_input, adjusted_decoder_output, mixing_prob_input])
tf_output = decoder_gru_tf(mixed_input, initial_state=encoder_gru)
tf_output = dense_output_adjust(tf_output) # Use dense_output_adjust layer for training output
training_model = Model(inputs=[encoder_input, decoder_gt_input, mixing_prob_input], outputs=tf_output, name="Training")
inference_model = Model(inputs=encoder_input, outputs=adjusted_decoder_output, name="Inference")
return training_model, inference_model
training_model, model = build_model()
def denormalize_y(y_data, y_mean, y_std, y_min, y_max, lat_min=42, lat_max=52, lon_min=-70, lon_max=-50):
scales = tf.constant([lon_max - lon_min, lat_max - lat_min], dtype=tf.float32)
biases = tf.constant([lon_min, lat_min], dtype=tf.float32)
# Reverse zero-one normalization and standardization
y_data = y_data * (y_max - y_min) + y_min
y_data = y_data * y_std + y_mean
# Reverse initial normalization for longitude and latitude
return y_data * scales + biases
def haversine_distance(lon1, lat1, lon2, lat2):
lon1, lat1, lon2, lat2 = [tf.math.multiply(x, tf.divide(tf.constant(np.pi), 180.)) for x in [lon1, lat1, lon2, lat2]] # lat and lon to radians
a = tf.math.square(tf.math.sin((lat2 - lat1) / 2.)) + tf.math.cos(lat1) * tf.math.cos(lat2) * tf.math.square(tf.math.sin((lon2 - lon1) / 2.))
return 2 * 6371000 * tf.math.asin(tf.math.sqrt(a)) # The Earth radius is 6,371,000 meters
def custom_loss(y_true, y_pred):
tf.debugging.check_numerics(y_true, "y_true contains NaNs")
tf.debugging.check_numerics(y_pred, "y_pred contains NaNs")
# Denormalize true and predicted y
y_true_denorm = denormalize_y(y_true, y_mean, y_std, y_min, y_max)
y_pred_denorm = denormalize_y(y_pred, y_mean, y_std, y_min, y_max)
# Compute haversine distance for true and predicted y from the second time-step
true_dist = haversine_distance(y_true_denorm[:, 1:, 0], y_true_denorm[:, 1:, 1], y_true_denorm[:, :-1, 0], y_true_denorm[:, :-1, 1])
pred_dist = haversine_distance(y_pred_denorm[:, 1:, 0], y_pred_denorm[:, 1:, 1], y_pred_denorm[:, :-1, 0], y_pred_denorm[:, :-1, 1])
# Convert maximum speed from knots to meters per 5 minutes
max_speed_m_per_5min = 50 * 1.852 * 1000 * 5 / 60
# Compute the difference in distances
dist_diff = tf.abs(true_dist - pred_dist)
# Apply penalty if the predicted distance is greater than the maximum possible distance
dist_diff = tf.where(pred_dist > max_speed_m_per_5min, pred_dist - max_speed_m_per_5min, dist_diff)
# Penalty for the first output coordinate not being the same as the last input
input_output_diff = haversine_distance(y_true_denorm[:, 0, 0], y_true_denorm[:, 0, 1], y_pred_denorm[:, 0, 0], y_pred_denorm[:, 0, 1])
# Compute RMSE excluding the first element
rmse = K.sqrt(K.mean(K.square(y_true_denorm[:, 1:, :] - y_pred_denorm[:, 1:, :]), axis=1))
tf.debugging.check_numerics(y_true_denorm, "y_true_denorm contains NaNs")
tf.debugging.check_numerics(y_pred_denorm, "y_pred_denorm contains NaNs")
tf.debugging.check_numerics(true_dist, "true_dist contains NaNs")
tf.debugging.check_numerics(pred_dist, "pred_dist contains NaNs")
tf.debugging.check_numerics(dist_diff, "dist_diff contains NaNs")
tf.debugging.check_numerics(input_output_diff, "input_output_diff contains NaNs")
tf.debugging.check_numerics(rmse, "rmse contains NaNs")
# Final loss with weights
# return 0.25 * K.mean(input_output_diff) + 0.35 * K.mean(dist_diff) + 0.40 * K.mean(rmse)
return K.mean(rmse)
def compile_model(model, learning_rate, clipnorm, jit_compile, skip_summary=False):
optimizer = AdamW(learning_rate=learning_rate, clipnorm=clipnorm, jit_compile=jit_compile)
model.compile(optimizer=optimizer, loss=custom_loss, metrics=["mae", "mape"], weighted_metrics=[], jit_compile=jit_compile)
if not skip_summary: model.summary() # print a summary of the model architecture
compile_model(training_model, learning_rate=0.001, clipnorm=1, jit_compile=True)
compile_model(model, learning_rate=0.001, clipnorm=1, jit_compile=True)
Model: "Training"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Encoder_Input (InputLayer) [(None, 48, 8)] 0 []
Encoder (GRU) (None, 64) 14208 ['Encoder_Input[0][0]']
Repeater (RepeatVector) (None, 95, 64) 0 ['Encoder[0][0]']
Decoder (GRU) (None, 95, 64) 24960 ['Repeater[0][0]',
'Encoder[0][0]']
Output_Adjust (TimeDistributed (None, 95, 2) 130 ['Decoder[0][0]',
) 'Decoder-TF[0][0]']
Decoder-GT-Input (InputLayer) [(None, 95, 2)] 0 []
Mixing_Probability (InputLayer [(None, 1)] 0 []
)
Probabilistic_Teacher_Forcing (None, 95, 2) 0 ['Decoder-GT-Input[0][0]',
(ProbabilisticTeacherForcing) 'Output_Adjust[0][0]',
'Mixing_Probability[0][0]']
Decoder-TF (GRU) (None, 95, 64) 13056 ['Probabilistic_Teacher_Forcing[0
][0]',
'Encoder[0][0]']
==================================================================================================
Total params: 52,354
Trainable params: 52,354
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "Inference"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Encoder_Input (InputLayer) [(None, 48, 8)] 0 []
Encoder (GRU) (None, 64) 14208 ['Encoder_Input[0][0]']
Repeater (RepeatVector) (None, 95, 64) 0 ['Encoder[0][0]']
Decoder (GRU) (None, 95, 64) 24960 ['Repeater[0][0]',
'Encoder[0][0]']
Output_Adjust (TimeDistributed (None, 95, 2) 130 ['Decoder[0][0]']
)
==================================================================================================
Total params: 39,298
Trainable params: 39,298
Non-trainable params: 0
__________________________________________________________________________________________________
The following function lists the callbacks used during the model training process. Callbacks are utilities invoked at specific points during training to monitor progress or take actions based on the model's performance. The function pre-defines the parameters and behavior of these callbacks:
WandbMetricsLogger: This callback logs the training and validation metrics for visualization and monitoring on the Weights & Biases (W&B) platform. This can be useful for tracking the training progress but may introduce additional overhead due to the logging process. You can remove this callback if you don't need to use W&B or want to reduce the overhead.
TerminateOnNaN: This callback terminates training if the loss becomes NaN (Not a Number) during the training process. It helps to stop the training process early when the model diverges and encounters an unstable state.
ReduceLROnPlateau: This callback reduces the learning rate by a specified factor when the monitored metric has stopped improving for several epochs. It helps fine-tune the model using a lower learning rate when it no longer improves significantly.
EarlyStopping: This callback stops the training process early when the monitored metric has not improved for a specified number of epochs. It restores the model's best weights when the training is terminated, preventing overfitting and reducing the training time.
ModelCheckpoint: This callback saves the best model (based on the monitored metric) to a file during training.
WandbMetricsLogger is the most computationally costly among these callbacks due to the logging process. You can remove this callback if you don't need Weights & Biases for monitoring or want to reduce overhead. The other callbacks help optimize the training process and are less computationally demanding. Note that the Weights & Biases (W&B) platform is also used in other parts of the code; if you decide to remove the WandbMetricsLogger callback, please ensure that you also remove any other references to W&B to avoid potential issues. If you choose to use W&B for monitoring and logging, you must register and log in on the W&B website. During execution, you'll be prompted for an authentication key to connect your script to your W&B account; this key can be obtained from your W&B account settings.
def create_callbacks(model_name, monitor="val_loss", factor=0.2, lr_patience=3, ep_patience=12, min_lr=0, verbose=0, restore_best_weights=True, skip_wandb=False):
return ([wandb.keras.WandbMetricsLogger()] if not skip_wandb else []) + [#tf.keras.callbacks.TerminateOnNaN(),
ReduceLROnPlateau(monitor=monitor, factor=factor, patience=lr_patience, min_lr=min_lr, verbose=verbose),
EarlyStopping(monitor=monitor, patience=ep_patience, verbose=verbose, restore_best_weights=restore_best_weights),
tf.keras.callbacks.ModelCheckpoint(os.path.join(ROOT, MODELS, model_name), monitor="val_loss", mode="min", save_best_only=True, verbose=verbose)]
def train_model(model, x_train, y_train, batch_size, epochs, validation_split, model_name):
run = wandb.init(project="kAISdb", anonymous="allow") # start the wandb run
# Set the initial mixing probability
mixing_prob = 0.5
# Update y_train to have the same dimensions as the output
y_train = y_train[:, :(OUTPUT_TIMESTEPS - 1), :]
# Create the ground truth input for the decoder by appending a padding at the beginning of the sequence
decoder_ground_truth_input_data = (np.zeros((y_train.shape[0], 1, y_train.shape[2])), y_train[:, :-1, :])
decoder_ground_truth_input_data = np.concatenate(decoder_ground_truth_input_data, axis=1)
try:
# Train the model with Teacher Forcing
with tf.device(tf.test.gpu_device_name()):
training_model.fit([x_train, decoder_ground_truth_input_data, np.full((x_train.shape[0], 1), mixing_prob)], y_train, batch_size=batch_size, epochs=epochs,
verbose=2, validation_split=validation_split, callbacks=create_callbacks(model_name))
# , sample_weight=straightness_ratios)
except KeyboardInterrupt as e:
print("\nRestoring best weights [...]")
# Load the weights of the teacher-forcing model
training_model.load_weights(model_name)
# Transfering the weights to the inference model
for layer in model.layers:
if layer.name in [l.name for l in training_model.layers]:
layer.set_weights(training_model.get_layer(layer.name).get_weights())
run.finish() # finish the wandb run
model_name = "TF-GRU-AE.h5"
full_path = os.path.join(ROOT, MODELS, model_name)
if True:#not os.path.exists(full_path):
train_model(model, x_train, y_train, batch_size=1024,
epochs=250, validation_split=0.2,
model_name=model_name)
else:
training_model.load_weights(full_path)
for layer in model.layers: # inference model initialization
if layer.name in [l.name for l in training_model.layers]:
layer.set_weights(training_model.get_layer(layer.name).get_weights())
def evaluate_model(model, x_test, y_test, y_mean, y_std, y_min, y_max, y_pred=None):
def single_trajectory_error(y_test, y_pred, index):
distances = haversine_distance(y_test[index, :, 0], y_test[index, :, 1], y_pred[index, :, 0], y_pred[index, :, 1])
return np.min(distances), np.max(distances), np.mean(distances), np.median(distances)
# Modify this function to handle teacher-forced models with 95 output variables instead of 96
def all_trajectory_error(y_test, y_pred):
errors = [single_trajectory_error(y_test[:, 1:], y_pred, i) for i in range(y_test.shape[0])]
min_errors, max_errors, mean_errors, median_errors = zip(*errors)
return min(min_errors), max(max_errors), np.mean(mean_errors), np.median(median_errors)
def plot_trajectory(x_test, y_test, y_pred, sample_index):
min_error, max_error, mean_error, median_error = single_trajectory_error(y_test, y_pred, sample_index)
fig = go.Figure()
fig.add_trace(go.Scatter(x=x_test[sample_index, :, 0], y=x_test[sample_index, :, 1], mode="lines", name="Input Data", line=dict(color="green")))
fig.add_trace(go.Scatter(x=y_test[sample_index, :, 0], y=y_test[sample_index, :, 1], mode="lines", name="Ground Truth", line=dict(color="blue")))
fig.add_trace(go.Scatter(x=y_pred[sample_index, :, 0], y=y_pred[sample_index, :, 1], mode="lines", name="Forecasted Trajectory", line=dict(color="red")))
fig.update_layout(title=f"Sample Index: {sample_index} | Distance Errors (in meteres):<br>Min: {min_error:.2f}m, Max: {max_error:.2f}m, "
f"Mean: {mean_error:.2f}m, Median: {median_error:.2f}m", xaxis_title="Longitude", yaxis_title="Latitude",
plot_bgcolor="#e4eaf0", paper_bgcolor="#fcfcfc", width=700, height=600)
max_lon, max_lat = -58.705587131108196, 47.89066160591873
min_lon, min_lat = -61.34247286889181, 46.09201839408127
fig.update_xaxes(range=[min_lon, max_lon])
fig.update_yaxes(range=[min_lat, max_lat])
return fig
if y_pred is None:
with tf.device(tf.test.gpu_device_name()):
y_pred = model.predict(x_test, verbose=0)
y_pred_o = y_pred # preserve the result
x_test = denormalize_x(x_test, x_mean, x_std, x_min, x_max)
y_pred = denormalize_y(y_pred_o, y_mean, y_std, y_min, y_max)
# Modify this line to handle teacher-forced models with 95 output variables instead of 96
for sample_index in [1000, 2500, 5000, 7500]:
display(plot_trajectory(x_test, y_test[:, 1:], y_pred, sample_index))
# The metrics require a lower dimension (no impact on the results)
y_test_reshaped = np.reshape(y_test[:, 1:], (-1, y_test.shape[2]))
y_pred_reshaped = np.reshape(y_pred, (-1, y_pred.shape[2]))
# Physical Distance Error given in meters
all_min_error, all_max_error, all_mean_error, all_median_error = all_trajectory_error(y_test, y_pred)
print("\nAll Trajectories Min DE: {:.4f}m".format(all_min_error))
print("All Trajectories Max DE: {:.4f}m".format(all_max_error))
print("All Trajectories Mean DE: {:.4f}m".format(all_mean_error))
print("All Trajectories Median DE: {:.4f}m".format(all_median_error))
# Calculate evaluation metrics on the test data
r2 = r2_score(y_test_reshaped, y_pred_reshaped)
mse = mean_squared_error(y_test_reshaped, y_pred_reshaped)
mae = mean_absolute_error(y_test_reshaped, y_pred_reshaped)
evs = explained_variance_score(y_test_reshaped, y_pred_reshaped)
mape = mean_absolute_percentage_error(y_test_reshaped, y_pred_reshaped)
rmse = np.sqrt(mse)
print(f"\nTest R^2: {r2:.4f}")
print(f"Test MAE: {mae:.4f}")
print(f"Test MSE: {mse:.4f}")
print(f"Test RMSE: {rmse:.4f}")
print(f"Test MAPE: {mape:.4f}")
print(f"Test Explained Variance Score: {evs:.4f}")
return y_pred_o
_ = evaluate_model(model, x_test, y_test, y_mean, y_std, y_min, y_max)
In this step, we define the functions used to tune the model with Hyperopt. The objective function takes a hyperparameter dictionary as input, drawn from a search space that defines the hyperparameters of interest. Specifically, we search for the best number of units in the encoder and decoder GRU layers, the teacher-forcing mixing probability, and the learning rate for the AdamW optimizer. The build_model function constructs the GRU AutoEncoder with these tunable hyperparameters, and the model is compiled with the same custom loss used above. Hyperopt then searches for the combination of hyperparameters that minimizes the validation loss, at the expense of longer computing time.
Helper for saving the training history:
def save_history(history, model_name):
history_name = model_name.replace('.h5', '.pkl')
history_name = os.path.join(ROOT, MODELS, history_name)
with open(history_name, 'wb') as f:
pkl.dump(history, f)
Helper for restoring the training history:
def load_history(model_name):
history_name = model_name.replace('.h5', '.pkl')
history_name = os.path.join(ROOT, MODELS, history_name)
with open(history_name, 'rb') as f:
history = pkl.load(f)
return history
Defining the model to be optimized:
def build_model(rnn_unit="GRU", enc_units_1=64, dec_units_1=64):
encoder_input = Input(shape=(INPUT_TIMESTEPS, INPUT_VARIABLES), name="Encoder_Input")
decoder_gt_input = Input(shape=((OUTPUT_TIMESTEPS - 1), OUTPUT_VARIABLES), name="Decoder-GT-Input")
mixing_prob_input = Input(shape=(1,), name="Mixing_Probability")
# Encoder
encoder_gru = eval(rnn_unit)(enc_units_1, activation="relu", name="Encoder")(encoder_input)
repeat_vector = RepeatVector((OUTPUT_TIMESTEPS - 1), name="Repeater")(encoder_gru)
# Inference Decoder
decoder_gru = eval(rnn_unit)(dec_units_1, activation="relu", return_sequences=True, name="Decoder")
decoder_output = decoder_gru(repeat_vector, initial_state=encoder_gru)
# Adjust decoder_output shape
dense_output_adjust = TimeDistributed(Dense(OUTPUT_VARIABLES), name="Output_Adjust")
adjusted_decoder_output = dense_output_adjust(decoder_output)
# Training Decoder
decoder_gru_tf = eval(rnn_unit)(dec_units_1, activation="relu", return_sequences=True, name="Decoder-TF")
probabilistic_tf_layer = ProbabilisticTeacherForcing(name="Probabilistic_Teacher_Forcing")
mixed_input = probabilistic_tf_layer([decoder_gt_input, adjusted_decoder_output, mixing_prob_input])
tf_output = decoder_gru_tf(mixed_input, initial_state=encoder_gru)
tf_output = dense_output_adjust(tf_output) # Use dense_output_adjust layer for training output
training_model = Model(inputs=[encoder_input, decoder_gt_input, mixing_prob_input], outputs=tf_output, name="Training")
inference_model = Model(inputs=encoder_input, outputs=adjusted_decoder_output, name="Inference")
return training_model, inference_model
HyperOpt Objective Function:
def objective(hyperparams, x_train, y_train, straightness_ratios, model_prefix):
# Get the best hyperparameters from the optimization results
enc_units_1 = hyperparams["enc_units_1"]
dec_units_1 = hyperparams["dec_units_1"]
mixing_prob = hyperparams["mixing_prob"]
lr = hyperparams["learning_rate"]
# Create the model name using the best hyperparameters
model_name = f"{model_prefix}-{enc_units_1}-{dec_units_1}-{mixing_prob}-{lr}.h5"
full_path = os.path.join(ROOT, MODELS, model_name) # best model full path
# Check if the model results file with this name already exists
if not os.path.exists(full_path.replace(".h5", ".pkl")):
print(f"Saving under {model_name}.")
# Define the model architecture
training_model, _ = build_model(enc_units_1=enc_units_1, dec_units_1=dec_units_1)
compile_model(training_model, learning_rate=lr, clipnorm=1, jit_compile=True, skip_summary=True)
# Update y_train to have the same dimensions as the output
y_train = y_train[:, :(OUTPUT_TIMESTEPS - 1), :]
# Create the ground truth input for the decoder by appending a padding at the beginning of the sequence
decoder_ground_truth_input_data = (np.zeros((y_train.shape[0], 1, y_train.shape[2])), y_train[:, :-1, :])
decoder_ground_truth_input_data = np.concatenate(decoder_ground_truth_input_data, axis=1)
# Train the model on the data, using GPU if available
with tf.device(tf.test.gpu_device_name()):
history = training_model.fit([x_train, decoder_ground_truth_input_data, np.full((x_train.shape[0], 1), mixing_prob)], y_train,
batch_size=10240, epochs=250, validation_split=.2, verbose=0,
workers=multiprocessing.cpu_count(), use_multiprocessing=True,
callbacks=create_callbacks(model_name, skip_wandb=True))
#, sample_weight=straightness_ratios)
# Save the training history
save_history(history.history, model_name)
# Clear the session to release resources
del training_model; tf.keras.backend.clear_session()
else:
print("Loading pre-trained weights.")
history = load_history(model_name)
if type(history) == dict: # validation loss of the model
return {"loss": history["val_loss"][-1], "status": STATUS_OK}
else: return {"loss": history.history["val_loss"][-1], "status": STATUS_OK}def optimize_hyperparameters(max_evals, model_prefix, x_train, y_train, sample_size=5000):
def build_space(n_min=2, n_steps=9):
# Defining a custom 2^N range function
n_range = lambda n_min, n_steps: np.array(
[2**n for n in range(n_min, n_steps) if 2**n >= n_min])
# Defining the unconstrained search space
encoder_1_range = n_range(n_min, n_steps)
decoder_1_range = n_range(n_min, n_steps)
learning_rate_range = [.01, .001, .0001]
mixing_prob_range = [.25, .5, .75]
# Enforcing constraints on the search space
enc_units_1 = np.random.choice(encoder_1_range)
dec_units_1 = np.random.choice(decoder_1_range[np.where(decoder_1_range == enc_units_1)])
learning_rate = np.random.choice(learning_rate_range)
mixing_prob = np.random.choice(mixing_prob_range)
# Returns a single element of the search space
return dict(enc_units_1=enc_units_1, dec_units_1=dec_units_1, learning_rate=learning_rate, mixing_prob=mixing_prob)
# Select the search space based on a pre-set sampled random space
search_space = hp.choice("hyperparams", [build_space() for _ in range(sample_size)])
trials = Trials() # initialize Hyperopt trials
# Define the objective function for Hyperopt
fn = lambda hyperparams: objective(hyperparams, x_train, y_train, straightness_ratios, model_prefix)
# Perform Hyperopt optimization and find the best hyperparameters
best = fmin(fn=fn, space=search_space, algo=tpe.suggest, max_evals=max_evals, trials=trials)
best_hyperparams = space_eval(search_space, best)
# Get the best hyperparameters from the optimization results
enc_units_1 = best_hyperparams["enc_units_1"]
dec_units_1 = best_hyperparams["dec_units_1"]
mixing_prob = best_hyperparams["mixing_prob"]
lr = best_hyperparams["learning_rate"]
# Create the model name using the best hyperparameters
model_name = f"{model_prefix}-{enc_units_1}-{dec_units_1}-{mixing_prob}-{lr}.h5"
full_path = os.path.join(ROOT, MODELS, model_name) # best model full path
t_model, i_model = build_model(enc_units_1=enc_units_1, dec_units_1=dec_units_1)
t_model = tf.keras.models.load_model(full_path, custom_objects={"ProbabilisticTeacherForcing": ProbabilisticTeacherForcing})
for layer in i_model.layers: # inference model initialization
if layer.name in [l.name for l in t_model.layers]:
layer.set_weights(t_model.get_layer(layer.name).get_weights())
print(f"Best hyperparameters:")
print(f" Encoder units 1: {enc_units_1}")
print(f" Decoder units 1: {dec_units_1}")
print(f" Mixing proba.: {mixing_prob}")
print(f" Learning rate: {lr}")
return i_model
max_evals, model_prefix = 100, "TF-GRU"
# best_model = optimize_hyperparameters(max_evals, model_prefix, x_train, y_train)
# [NOTE] YOU CAN SKIP THIS STEP BY LOADING THE PRE-TRAINED WEIGHTS ON THE NEXT CELL.
Sweeping the project folder for other pre-trained weights shared with this tutorial:
def find_best_model(root_folder, model_prefix):
best_model_name, best_val_loss = None, float('inf')
for f in os.listdir(root_folder):
if (f.endswith(".h5") and f.startswith(model_prefix)):
try:
history = load_history(f)
# Get the validation loss
if type(history) == dict:
val_loss = history["val_loss"][-1]
else: val_loss = history.history["val_loss"][-1]
# Storing the best model
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model_name = f
except: pass
# Load the best model
full_path = os.path.join(ROOT, MODELS, best_model_name)
t_model, i_model = build_model(enc_units_1=int(best_model_name.split("-")[2]), dec_units_1=int(best_model_name.split("-")[3]))
t_model = tf.keras.models.load_model(full_path, custom_objects={"ProbabilisticTeacherForcing": ProbabilisticTeacherForcing})
for layer in i_model.layers: # inference model initialization
if layer.name in [l.name for l in t_model.layers]:
layer.set_weights(t_model.get_layer(layer.name).get_weights())
# Print summary of the best model
print(f"Loss: {best_val_loss}")
i_model.summary()
return i_model
best_model = find_best_model(os.path.join(ROOT, MODELS), model_prefix)
_ = evaluate_model(best_model, x_test, y_test, y_mean, y_std, y_min, y_max)
Deep learning models, although powerful, are often criticized for their lack of explainability, making it difficult to comprehend their decision-making process and raising concerns about trust and reliability. To address this issue, we can use techniques like the Permutation Feature Importance (PFI) method, a simple, model-agnostic approach that helps visualize the importance of features in deep learning models. This method works by shuffling individual feature values in the dataset and observing the impact on the model's performance. By measuring the change in a designated metric when each feature's values are randomly permuted, we can infer the importance of that specific feature. The idea is that if a feature is crucial for the model's performance, shuffling its values should lead to a significant shift in performance; conversely, if a feature has little impact, permuting its values should result in only a minor change. Applying the permutation feature importance method to the best model, obtained after hyperparameter tuning, gives us a more transparent understanding of how the model makes its decisions.
def permutation_feature_importance(model, x_test, y_test, metric):
# Function to calculate permutation feature importance
def PFI(model, x, y_true, metric):
# Reshape the true values for easier comparison with predictions
y_true = np.reshape(y_true, (-1, y_true.shape[2]))
# Predict using the model and reshape the predicted values
with tf.device(tf.test.gpu_device_name()):
y_pred = model.predict(x, verbose=0)
y_pred = np.reshape(y_pred, (-1, y_pred.shape[2]))
# Calculate the baseline score using the given metric
baseline_score = metric(y_true, y_pred)
# Initialize an array for feature importances
feature_importances = np.zeros(x.shape[2])
# Calculate the importance for each feature
for feature_idx in range(x.shape[2]):
x_permuted = x.copy()
x_permuted[:, :, feature_idx] = np.random.permutation(x[:, :, feature_idx])
# Predict using the permuted input and reshape the predicted values
with tf.device(tf.test.gpu_device_name()):
y_pred_permuted = model.predict(x_permuted, verbose=0)
y_pred_permuted = np.reshape(y_pred_permuted, (-1, y_pred_permuted.shape[2]))
# Calculate the score with permuted input
permuted_score = metric(y_true, y_pred_permuted)
# Compute the feature importance as the difference between permuted and baseline scores
feature_importances[feature_idx] = permuted_score - baseline_score
return feature_importances
feature_importances = PFI(model, x_test, y_test, metric)
# Prepare the data for plotting (Altair requires a DataFrame)
feature_names = ["Longitude", "Latitude", "COG", "SOG"]
feature_importance_df = pd.DataFrame({"features": feature_names, "importance": feature_importances})
# Create the bar plot with Altair
bar_plot = alt.Chart(feature_importance_df).mark_bar(size=40, color="mediumblue", opacity=0.8).encode(
x=alt.X("features:N", title="Features", axis=alt.Axis(labelFontSize=12, titleFontSize=14)),
y=alt.Y("importance:Q", title="Permutation Importance", axis=alt.Axis(labelFontSize=12, titleFontSize=14)),
).properties(title=alt.TitleParams(text="Feature Importance", fontSize=16, fontWeight="bold"), width=400, height=300)
return bar_plot, feature_importances
permutation_feature_importance(best_model, x_test, y_test, mean_absolute_error)[0].display()
Permutation feature importance has some limitations, such as assuming features are independent and producing biased results when features are highly correlated. It also does not provide detailed explanations for individual data points. An alternative is sensitivity analysis, which studies how input features affect model predictions. By perturbing each input feature individually and observing the resulting prediction changes, we can understand which features significantly impact the model's output. This approach offers insights into the model's decision-making process and helps identify influential features. However, it does not account for feature interactions and can be computationally expensive for many features or perturbation steps.
def sensitivity_analysis(model, x_sample, perturbation_range=(-0.1, 0.1), num_steps=10, plot_nrows=4):
# Get the number of features and outputs
num_features = x_sample.shape[1]
num_outputs = model.output_shape[-1] * model.output_shape[-2]
# Create an array of perturbations
perturbations = np.linspace(perturbation_range[0], perturbation_range[1], num_steps)
# Initialize sensitivity array
sensitivity = np.zeros((num_features, num_outputs, num_steps))
# Get the original prediction for the input sample
original_prediction = model.predict(x_sample.reshape(1, -1, 4), verbose=0).reshape(-1)
# Iterate over input features and perturbations
for feature_idx in range(num_features):
for i, perturbation in enumerate(perturbations):
# Create a perturbed version of the input sample
perturbed_sample = x_sample.copy()
perturbed_sample[:, feature_idx] += perturbation
# Get the prediction for the perturbed input sample
perturbed_prediction = model.predict(perturbed_sample.reshape(1, -1, 4), verbose=0).reshape(-1)
# Calculate the absolute prediction change and store it in the sensitivity array
sensitivity[feature_idx, :, i] = np.abs(perturbed_prediction - original_prediction)
# Determine the number of rows and columns in the plot
ncols = 6
nrows = max(min(plot_nrows, math.ceil(num_outputs / ncols)), 1)
# Define feature names
feature_names = ["Longitude", "Latitude", "COG", "SOG"]
# Create the sensitivity plot
fig, axs = plt.subplots(nrows, ncols, figsize=(18, 3 * nrows), sharex=True, sharey=True)
axs = axs.ravel()
output_idx = 0
for row in range(nrows):
for col in range(ncols):
if output_idx < num_outputs:
# Plot sensitivity curves for each feature
for feature_idx in range(num_features):
axs[output_idx].plot(perturbations, sensitivity[feature_idx, output_idx], label=f'{feature_names[feature_idx]}')
# Set the title for each subplot
axs[output_idx].set_title(f'Output {output_idx // 2 + 1}, {"Longitude" if output_idx % 2 == 0 else "Latitude"}')
output_idx += 1
# Set common labels and legend
fig.text(0.5, 0.04, 'Perturbation', ha='center', va='center')
fig.text(0.06, 0.5, 'Absolute Prediction Change', ha='center', va='center', rotation='vertical')
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=num_features, bbox_to_anchor=(.5, .87))
plt.tight_layout()
plt.subplots_adjust(top=0.8, bottom=0.1, left=0.1, right=0.9)
plt.show()
return sensitivity
x_sample = x_test[100] # Select a sample from the test set
sensitivity = sensitivity_analysis(best_model, x_sample)
UMAP is a nonlinear dimensionality reduction technique that visualizes high-dimensional data in a lower-dimensional space, preserving the local and global structure. In trajectory forecasting, UMAP can project high-dimensional model representations into 2D or 3D to clarify the relationships between input features and outputs. Unlike sensitivity analysis, which measures prediction changes due to input feature perturbations, UMAP reveals data structure without perturbations. It also differs from feature permutation, which evaluates feature importance by shuffling values and assessing model performance changes. UMAP focuses on visualizing intrinsic data structures and relationships.
def visualize_intermediate_representations(model, x_test_subset, y_test_subset, n_neighbors=15, min_dist=0.1, n_components=2):
# Extract intermediate representations from your model
intermediate_layer_model = keras.Model(inputs=model.input, outputs=model.layers[-2].output)
intermediate_output = intermediate_layer_model.predict(x_test_subset, verbose=0)
# Flatten the last two dimensions of the intermediate_output
flat_intermediate_output = intermediate_output.reshape(intermediate_output.shape[0], -1)
# UMAP
reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=seed_value)
umap_output = reducer.fit_transform(flat_intermediate_output)
# Convert y_test_subset to strings
y_test_str = np.array([str(label) for label in y_test_subset])
# Map string labels to colors
unique_labels = np.unique(y_test_str)
colormap = plt.cm.get_cmap('viridis', len(unique_labels))
label_to_color = {label: colormap(i) for i, label in enumerate(unique_labels)}
colors = np.array([label_to_color[label] for label in y_test_str])
# Create plot with Matplotlib
fig, ax = plt.subplots(figsize=(10, 8))
sc = ax.scatter(umap_output[:, 0], umap_output[:, 1], c=colors, s=5)
ax.set_title("UMAP Visualization", fontsize=14, fontweight="bold")
ax.set_xlabel("X Dimension", fontsize=12)
ax.set_ylabel("Y Dimension", fontsize=12)
ax.grid(True, linestyle='--', alpha=0.5)
# Add a colorbar to the plot
sm = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin=0, vmax=len(unique_labels)-1))
sm.set_array([])
cbar = plt.colorbar(sm, ticks=range(len(unique_labels)), ax=ax)
cbar.ax.set_yticklabels(unique_labels)
cbar.set_label("MMSIs")
plt.show()
visualize_intermediate_representations(best_model, x_test[:10000], mmsi_test[:10000], n_neighbors=10, min_dist=0.5)
GRUs can effectively forecast vessel trajectories but have notable downsides. A primary limitation is their struggle with long-term dependencies due to the vanishing gradient problem, causing the loss of relevant information from earlier time steps. This makes capturing long-term patterns in vessel trajectories challenging. Additionally, GRUs are computationally expensive with large datasets and long sequences, resulting in longer training times and higher memory use. While outperforming basic RNNs, they may not always surpass advanced architectures like LSTMs or Transformer models. Furthermore, the interpretability of GRU-based models is a challenge, which can hinder their adoption in safety-critical applications like vessel trajectory forecasting.
AIS data is messy. It comes as a stream of latitude, longitude, timestamps, and a few metadata fields. By themselves, these raw records are hard to compare or feed into downstream ML models. What we really want is a compact, fixed-length representation of each trajectory that can be compared, clustered, and fed into downstream models.
Inspired by word2vec, traj2vec applies the same logic to movement data: instead of predicting the next word, it predicts the next location in a sequence. Just like words gain meaning from context, vessel positions gain meaning from their trajectory history.
The result: trajectories that “look alike” end up close together in embedding space. For instance, two ferries running parallel routes will embed similarly, while a cargo vessel crossing the Gulf of Mexico will sit far away from a fishing boat looping off the coast.
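To make "closeness in embedding space" concrete, here is a minimal, self-contained sketch that is not part of the tutorial pipeline: assuming two trajectories have already been embedded as fixed-length vectors, their similarity can be measured with cosine similarity. The vectors below are made-up placeholders for illustration only.
import numpy as np
def cosine_similarity(a, b):
    # Cosine of the angle between two embedding vectors (1 = identical direction)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))
emb_ferry_a = np.array([0.9, 0.1, 0.3])    # hypothetical embedding of one ferry route
emb_ferry_b = np.array([0.8, 0.2, 0.35])   # hypothetical embedding of a parallel ferry route
emb_cargo = np.array([-0.5, 0.7, -0.2])    # hypothetical embedding of a cargo crossing
print(cosine_similarity(emb_ferry_a, emb_ferry_b))  # high: similar routes embed close together
print(cosine_similarity(emb_ferry_a, emb_cargo))    # low: dissimilar behavior embeds far apart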
import os
import h3
import json
import aisdb
import cartopy.feature as cfeature
import cartopy.crs as ccrs
from aisdb.database.dbconn import PostgresDBConn
from aisdb.denoising_encoder import encode_greatcircledistance, InlandDenoising
from aisdb.track_gen import min_speed_filter, min_track_length_filter
from aisdb.database import sqlfcn
from datetime import datetime, timedelta
from collections import defaultdict
from tqdm import tqdm
import pprint
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt
import nest_asyncio
nest_asyncio.apply()
This function pulls raw AIS data from a database, denoises it, splits tracks into time-consistent segments, filters outliers, and interpolates them at fixed time steps. The result is a set of clean, continuous vessel trajectories ready for embedding.
def process_interval(dbconn, start, end):
# Open a new connection with the database
qry = aisdb.DBQuery(dbconn=dbconn, start=start, end=end,
xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax,
callback=aisdb.database.sqlfcn_callbacks.in_bbox_time_validmmsi)
# Decimate is for removing unnecessary points in the trajectory
rowgen = qry.gen_qry(fcn=sqlfcn.crawl_dynamic_static)
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
with InlandDenoising(data_dir='./data/tmp/') as remover:
cleaned_tracks = remover.filter_noisy_points(tracks)
# Split the tracks based on the time between transmissions
track_segments = aisdb.track_gen.split_timedelta(cleaned_tracks, time_split)
# Filter out segments that are below the min-score threshold
tracks_encoded = encode_greatcircledistance(track_segments, distance_threshold=distance_split, speed_threshold=speed_split)
tracks_encoded = min_speed_filter(tracks_encoded, minspeed=1)
# Interpolate the segments at a fixed one-minute step to enforce continuity
tracks_interpolated = aisdb.interp.interp_time(tracks_encoded, step=timedelta(minutes=1))
# Materialize the interpolated tracks into a list before returning
return list(tracks_interpolated)
We load the study region (Gulf shapefile) and a hexagonal grid (H3 resolution 6). These will be used to map vessel positions into discrete spatial cells: the "tokens" for our trajectory embedding model.
# Load the shapefile
gulf_shapefile = './data/region/gulf.shp'
print(f"Loading shapefile from {gulf_shapefile}...")
gdf_gulf = gpd.read_file(gulf_shapefile)
gdf_hexagons = gpd.read_file('./data/cell/Hexagons_6.shp')
gdf_hexagons = gdf_hexagons.to_crs(epsg=4326)  # Reproject to a consistent CRS (WGS84)
Each trajectory is converted from lat/lon coordinates into H3 hexagon IDs at resolution 6. To avoid redundant entries, we deduplicate consecutive identical cells while keeping the timestamp of the first entry. The result is a sequence of discrete spatial tokens with time information, which is the input format for traj2vec.
# valid_h3_ids = set(gdf_hexagons['hex_id'])
bounding_box = gdf_hexagons.total_bounds # Extract the bounding box
# bounding_box = gdf_gulf.total_bounds # Extract the bounding box
xmin, ymin, xmax, ymax = bounding_box # Split the bounding box
start_date = datetime(2023, 1, 1)
end_date = datetime(2023, 1, 30)
print(f"Processing trajectories from {start_date} to {end_date}")
# Define pre-processing parameters
time_split = timedelta(hours=3)
distance_split = 10000 # meters
speed_split = 40 # knots
cell_visits = defaultdict(lambda: defaultdict(list))
filtered_visits = defaultdict(lambda: defaultdict(list))
g2h3_vec = np.vectorize(h3.latlng_to_cell)
pp = pprint.PrettyPrinter(indent=4)
track_info_list = []
track_list = process_interval(dbconn, start_date, end_date)
for track in tqdm(track_list, total=len(track_list), desc="Vessels", leave=False):
h3_ids = g2h3_vec(track['lat'], track['lon'], 6)
timestamps = track['time']
# Identify the entry points of cells on a track
# Deduplicate consecutive identical h3_ids while preserving the entry timestamp
dedup_h3_ids = [h3_ids[0]]
dedup_timestamps = [timestamps[0]]
for i in range(1, len(h3_ids)):
if h3_ids[i] != dedup_h3_ids[-1]:
dedup_h3_ids.append(h3_ids[i])
dedup_timestamps.append(timestamps[i])
track_info = {
"mmsi": track['mmsi'],
"h3_seq": dedup_h3_ids,
"timestamp_seq": dedup_timestamps
}
track_info_list.append(track_info)
Before training embeddings, it is useful to check how long our AIS trajectories are. The function below computes summary statistics (min, max, mean, percentiles) and plots the distribution of track lengths in terms of H3 cells.
In our dataset, the distribution is skewed to the right: most vessel tracks are relatively short, with only a few very long trajectories.
For a simpler visualization, we also plot trajectories in raw lat/lon space without cartographic features. This is handy for debugging and checking if preprocessing (deduplication, interpolation) worked correctly.
import seaborn as sns
def plot_length_distribution(track_lengths):
# Compute summary stats
length_stats = {
"min": int(np.min(track_lengths)),
"max": int(np.max(track_lengths)),
"mean": float(np.mean(track_lengths)),
"median": float(np.median(track_lengths)),
"percentiles": {
"10%": int(np.percentile(track_lengths, 10)),
"25%": int(np.percentile(track_lengths, 25)),
"50%": int(np.percentile(track_lengths, 50)),
"75%": int(np.percentile(track_lengths, 75)),
"90%": int(np.percentile(track_lengths, 90)),
"95%": int(np.percentile(track_lengths, 95)),
}
}
print(length_stats)
# Plot distribution
plt.figure(figsize=(10, 6))
sns.histplot(track_lengths, bins=100, kde=True)
plt.title("Distribution of Track Lengths")
plt.xlabel("Track Length (number of H3 cells)")
plt.ylabel("Frequency")
plt.grid(True)
plt.tight_layout()
plt.show()
def map_view(tracks, dot_size=3, color=None, save=False, path=None, bbox=None, line=False, line_width=0.5, line_opacity=0.3):
fig = plt.figure(figsize=(16, 9))
ax = plt.axes(projection=ccrs.PlateCarree())
# Add cartographic features
ax.add_feature(cfeature.OCEAN.with_scale('10m'), facecolor='#E0E0E0')
ax.add_feature(cfeature.LAND.with_scale('10m'), facecolor='#FFE5CC')
ax.add_feature(cfeature.BORDERS, linestyle=':')
ax.add_feature(cfeature.LAKES, alpha=0.5)
ax.add_feature(cfeature.RIVERS)
ax.coastlines(resolution='10m')
if line:
for track in tqdm(tracks):
ax.plot(track['lon'], track['lat'], color=color, linewidth=line_width, alpha=line_opacity, transform=ccrs.PlateCarree())
else:
for track in tqdm(tracks):
ax.scatter(track['lon'], track['lat'], c=color, s=dot_size, transform=ccrs.PlateCarree())
if bbox:
# Set the map extent based on a bounding box
ax.set_extent(bbox, crs=ccrs.PlateCarree())
ax.gridlines(draw_labels=True)
if save:
plt.savefig(path, dpi=300, transparent=True)
plt.show()
def hex_view(lats, lons, save=True):
plt.figure(figsize=(8,8))
for traj_lat, traj_lon in zip(lats, lons):
plt.plot(traj_lon, traj_lat, alpha=0.3, linewidth=1)
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title("Test Trajectories")
plt.axis("equal")
if save:
plt.savefig("img/test_track.png", dpi=300)
Filtering out tracks that are too short or too long, keeping only sequences with 10 to 300 H3 cells:
track_info_list = [t for t in track_info_list if (len(t['h3_seq']) >= 10)&(len(t['h3_seq']) <= 300)]
We collect all unique H3 IDs from the trajectories and assign each one an integer index. Just like in NLP, we also reserve special tokens for padding, start, and end of sequence. This turns spatial cells into a vocabulary that our embedding model can work with.
Each vessel track is then mapped from its H3 sequence into an integer sequence (int_seq). We also convert the H3 cells back into lat/lon pairs for later visualization. At this point, the data is ready to be fed into a traj2vec-style model.
vec_cell_to_latlng = np.vectorize(h3.cell_to_latlng)
# Extract hex ids from all tracks
all_h3_ids = set()
for track in track_info_list:
all_h3_ids.update(track['h3_seq']) # or t['int_seq'] if already mapped
# Build vocab: reserve 0,1,2 for BOS, EOS, PAD
h3_vocab = {h: i+3 for i, h in enumerate(sorted(all_h3_ids))}
# h3_vocab = {h: i+3 for i, h in enumerate(sorted(valid_h3_ids))}
special_tokens = {"<PAD>": 0, "<BOS>": 1, "<EOS>": 2}
h3_vocab.update(special_tokens)
for t in track_info_list:
t["int_seq"] = [h3_vocab[h] for h in t["h3_seq"] if h in h3_vocab]
t["lat"], t["lon"] = vec_cell_to_latlng(t.get('h3_seq'))
We split the cleaned trajectories into train, validation, and test sets. This ensures our model can be trained, tuned, and evaluated fairly without data leakage.
Each trajectory is written out in multiple aligned formats:
.src → input sequence (all tokens except last)
.trg → target sequence (all tokens except first)
.lat / .lon → raw geographic coordinates (for visualization)
.t → the complete trajectory sequence
This setup mirrors NLP datasets, where models learn to predict the “next token” in a sequence.
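As a toy walk-through (with hypothetical cell IDs and vocabulary indices), this is how one short track would end up in those files:
# Toy example only: three made-up H3 cells and a tiny vocabulary with the reserved special tokens
toy_vocab = {"<PAD>": 0, "<BOS>": 1, "<EOS>": 2,
             "860e4d31fffffff": 3, "860e4d317ffffff": 4, "860e4d303ffffff": 5}
h3_seq = ["860e4d31fffffff", "860e4d317ffffff", "860e4d303ffffff"]
int_seq = [toy_vocab[h] for h in h3_seq]   # [3, 4, 5]
src, trg = int_seq[:-1], int_seq[1:]       # the model sees [3, 4] and learns to predict [4, 5]
print(" ".join(map(str, int_seq)))         # line written to the _trj.t file
print(" ".join(map(str, src)))             # line written to the .src file
print(" ".join(map(str, trg)))             # line written to the .trg file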
from sklearn.model_selection import train_test_split
# Initial split: train vs temp
train_tracks, temp_tracks = train_test_split(track_info_list, test_size=0.4, random_state=42)
# Second split: validation vs test
val_tracks, test_tracks = train_test_split(temp_tracks, test_size=0.5, random_state=42)
def save_data(tracks, prefix, output_dir="data"):
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, f"{prefix}.src"), "w") as f_src, \
open(os.path.join(output_dir, f"{prefix}.trg"), "w") as f_trg, \
open(os.path.join(output_dir, f"{prefix}.lat"), "w") as f_lat, \
open(os.path.join(output_dir, f"{prefix}.lon"), "w") as f_lon, \
open(os.path.join(output_dir, f"{prefix}_trj.t"), "w") as f_t:
for idx, t in enumerate(tracks):
ids = t["int_seq"]
src = ids[:-1]
trg = ids[1:]
f_t.write(" ".join(map(str, ids)) + "\n") # the whole track, t = src U trg
f_src.write(" ".join(map(str, src)) + "\n")
f_trg.write(" ".join(map(str, trg)) + "\n")
f_lat.write(" ".join(map(str, t.get('lat'))) + "\n")
f_lon.write(" ".join(map(str, t.get('lon'))) + "\n")
def save_h3_vocab(h3_vocab, output_dir="data", filename="vocab.json"):
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, filename), "w") as f:
json.dump(h3_vocab, f, indent=2)
save_data(train_tracks, "train")
save_data(val_tracks, "val")
save_data(test_tracks, "test")
save_h3_vocab(h3_vocab) # save the INT index mapping to H3 index
lats = [np.fromstring(line, sep=' ') for line in open("data/train.lat")]
lons = [np.fromstring(line, sep=' ') for line in open("data/train.lon")]
hex_view(lats, lons)
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
# from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.tensorboard import SummaryWriter
# from funcy import merge
import time, os, shutil, logging, h5py
# from collections import namedtuple
from model.t2vec import EncoderDecoder
from data_loader import DataLoader
from utils import *
from model.loss import *
writer = SummaryWriter()
PAD = 0
BOS = 1
EOS = 2
We set up utility functions to initialize model weights, save checkpoints during training, and run validation. These ensure training is reproducible and models can be restored later.
The train() function loads train/val datasets, defines the loss functions (negative log-likelihood or KL-divergence), and builds the encoder-decoder model with its optimizer and scheduler. If a checkpoint exists, training resumes from where it left off; otherwise, parameters are freshly initialized.
Training uses two objectives:
Generative loss (predicting the next trajectory cell, like word prediction in NLP).
Discriminative loss (triplet margin loss, ensuring embeddings of similar trajectories are close while different ones are far apart).
These combined losses help the model learn not only to generate realistic trajectories but also to embed them in a useful vector space.
The loop runs over iterations, logging training progress, validating periodically, and saving checkpoints. A learning rate scheduler adjusts the optimizer based on validation loss, and early stopping prevents wasted computation when no improvements occur.
The test() function loads the best checkpoint, evaluates it on the test set, and reports average loss and perplexity. Perplexity is borrowed from NLP — lower values mean the model is more confident in predicting the next trajectory cell.
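For intuition, perplexity is simply the exponential of the average per-token loss. A quick sanity check, using the average loss reported in the results further down this page, looks like this:
import math
avg_nll = 0.2309           # average negative log-likelihood per token (see the test results below)
print(math.exp(avg_nll))   # ~1.26: the model assigns very high probability to the correct next cell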
def init_parameters(model):
for p in model.parameters():
p.data.uniform_(-0.1, 0.1)
def savecheckpoint(state, is_best, args):
torch.save(state, args.checkpoint)
if is_best:
shutil.copyfile(args.checkpoint, os.path.join(args.data, 'best_model.pt'))
def validate(valData, model, lossF, args):
"""
valData (DataLoader)
"""
m0, m1 = model
## switch to evaluation mode
m0.eval()
m1.eval()
num_iteration = valData.size // args.batch
if valData.size % args.batch > 0: num_iteration += 1
total_genloss = 0
for iteration in range(num_iteration):
gendata = valData.getbatch_generative()
with torch.no_grad():
genloss = genLoss(gendata, m0, m1, lossF, args)
total_genloss += genloss.item() * gendata.trg.size(1)
## switch back to training mode
m0.train()
m1.train()
return total_genloss / valData.size
def train(args):
logging.basicConfig(filename=os.path.join(args.data, "training.log"), level=logging.INFO)
trainsrc = os.path.join(args.data, "train.src") # data path
traintrg = os.path.join(args.data, "train.trg")
trainlat = os.path.join(args.data, "train.lat")
trainlon = os.path.join(args.data, "train.lon")
# trainmta = os.path.join(args.data, "train.mta")
trainData = DataLoader(trainsrc, traintrg, trainlat, trainlon, args.batch, args.bucketsize)
print("Reading training data...")
trainData.load(args.max_num_line)
print("Allocation: {}".format(trainData.allocation))
print("Percent: {}".format(trainData.p))
valsrc = os.path.join(args.data, "val.src")
valtrg = os.path.join(args.data, "val.trg")
vallat = os.path.join(args.data, "val.lat")
vallon = os.path.join(args.data, "val.lon")
if os.path.isfile(valsrc) and os.path.isfile(valtrg):
valData = DataLoader(valsrc, valtrg, vallat, vallon, args.batch, args.bucketsize, validate=True)
print("Reading validation data...")
valData.load()
assert valData.size > 0, "Validation data size must be greater than 0"
print("Loaded validation data size {}".format(valData.size))
else:
print("No validation data found, training without validating...")
## create criterion, model, optimizer
if args.criterion_name == "NLL":
criterion = NLLcriterion(args.vocab_size)
lossF = lambda o, t: criterion(o, t)
else:
assert os.path.isfile(args.knearestvocabs),\
"{} does not exist".format(args.knearestvocabs)
print("Loading vocab distance file {}...".format(args.knearestvocabs))
with h5py.File(args.knearestvocabs, "r") as f:
V, D = f["V"][...], f["D"][...]
V, D = torch.LongTensor(V), torch.FloatTensor(D)
D = dist2weight(D, args.dist_decay_speed)
if args.cuda and torch.cuda.is_available():
V, D = V.cuda(), D.cuda()
criterion = KLDIVcriterion(args.vocab_size)
lossF = lambda o, t: KLDIVloss(o, t, criterion, V, D)
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
m0 = EncoderDecoder(args.vocab_size,
args.embedding_size,
args.hidden_size,
args.num_layers,
args.dropout,
args.bidirectional)
m1 = nn.Sequential(nn.Linear(args.hidden_size, args.vocab_size),
nn.LogSoftmax(dim=1))
if args.cuda and torch.cuda.is_available():
print("=> training with GPU")
m0.cuda()
m1.cuda()
criterion.cuda()
#m0 = nn.DataParallel(m0, dim=1)
else:
print("=> training with CPU")
m0_optimizer = torch.optim.Adam(m0.parameters(), lr=args.learning_rate)
m1_optimizer = torch.optim.Adam(m1.parameters(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
m0_optimizer,
mode='min',
patience=args.lr_decay_patience,
min_lr=0,
verbose=True
)
## load model state and optmizer state
if os.path.isfile(args.checkpoint):
print("=> loading checkpoint '{}'".format(args.checkpoint))
logging.info("Restore training @ {}".format(time.ctime()))
checkpoint = torch.load(args.checkpoint)
args.start_iteration = checkpoint["iteration"]
best_prec_loss = checkpoint["best_prec_loss"]
m0.load_state_dict(checkpoint["m0"])
m1.load_state_dict(checkpoint["m1"])
m0_optimizer.load_state_dict(checkpoint["m0_optimizer"])
m1_optimizer.load_state_dict(checkpoint["m1_optimizer"])
else:
print("=> no checkpoint found at '{}'".format(args.checkpoint))
logging.info("Start training @ {}".format(time.ctime()))
best_prec_loss = float('inf')
args.start_iteration = 0
print("=> initializing the parameters...")
init_parameters(m0)
init_parameters(m1)
## here: optionally load a pretrained word (cell) embedding
num_iteration = 6700*128 // args.batch
print("Iteration starts at {} "
"and will end at {}".format(args.start_iteration, num_iteration-1))
no_improvement_count = 0
## training
for iteration in range(args.start_iteration, num_iteration):
try:
m0_optimizer.zero_grad()
m1_optimizer.zero_grad()
## generative loss
gendata = trainData.getbatch_generative()
genloss = genLoss(gendata, m0, m1, lossF, args)
## discriminative loss
disloss_cross, disloss_inner = 0, 0
if args.use_discriminative and iteration % 5 == 0:
a, p, n = trainData.getbatch_discriminative_cross()
disloss_cross = disLoss(a, p, n, m0, triplet_loss, args)
a, p, n = trainData.getbatch_discriminative_inner()
disloss_inner = disLoss(a, p, n, m0, triplet_loss, args)
loss = genloss + args.discriminative_w * (disloss_cross + disloss_inner)
# Add to tensorboard
writer.add_scalar('Loss/train', loss, iteration)
## compute the gradients
loss.backward()
## clip the gradients
clip_grad_norm_(m0.parameters(), args.max_grad_norm)
clip_grad_norm_(m1.parameters(), args.max_grad_norm)
## one step optimization
m0_optimizer.step()
m1_optimizer.step()
## average loss for one word
avg_genloss = genloss.item() / gendata.trg.size(0)
if iteration % args.print_freq == 0:
print("Iteration: {0:}\tGenerative Loss: {1:.3f}\t"\
"Discriminative Cross Loss: {2:.3f}\tDiscriminative Inner Loss: {3:.3f}"\
.format(iteration, avg_genloss, disloss_cross, disloss_inner))
if iteration % args.save_freq == 0 and iteration > 0:
prec_loss = validate(valData, (m0, m1), lossF, args)
# Add to tensorboard
writer.add_scalar('Loss/validation', prec_loss, iteration)
scheduler.step(prec_loss)
if prec_loss < best_prec_loss:
best_prec_loss = prec_loss
logging.info("Best model with loss {} at iteration {} @ {}"\
.format(best_prec_loss, iteration, time.ctime()))
is_best = True
no_improvement_count = 0
else:
is_best = False
no_improvement_count += 1
print("Saving the model at iteration {} validation loss {}"\
.format(iteration, prec_loss))
savecheckpoint({
"iteration": iteration,
"best_prec_loss": best_prec_loss,
"m0": m0.state_dict(),
"m1": m1.state_dict(),
"m0_optimizer": m0_optimizer.state_dict(),
"m1_optimizer": m1_optimizer.state_dict()
}, is_best, args)
# Early stopping if there is no improvement after a certain number of epochs
if no_improvement_count >= args.early_stopping_patience:
print('No improvement after {} iterations, early stopping triggered.'.format(args.early_stopping_patience))
break
except KeyboardInterrupt:
break
def test(args):
# load testing data
testsrc = os.path.join(args.data, "test.src")
testtrg = os.path.join(args.data, "test.trg")
testlat = os.path.join(args.data, "test.lat")
testlon = os.path.join(args.data, "test.lon")
if os.path.isfile(testsrc) and os.path.isfile(testtrg):
testData = DataLoader(testsrc, testtrg, testlat, testlon, args.batch, args.bucketsize, validate=True)
print("Reading testing data...")
testData.load()
assert testData.size > 0, "Testing data size must be greater than 0"
print("Loaded testing data size {}".format(testData.size))
else:
print("No testing data found, aborting test.")
return
# set up model
m0 = EncoderDecoder(args.vocab_size,
args.embedding_size,
args.hidden_size,
args.num_layers,
args.dropout,
args.bidirectional)
m1 = nn.Sequential(nn.Linear(args.hidden_size, args.vocab_size),
nn.LogSoftmax(dim=1))
# load best model state
best_model_path = 'data/best_model.pt'
if os.path.isfile(best_model_path):
print("=> loading checkpoint '{}'".format(args.checkpoint))
best_model = torch.load(best_model_path)
m0.load_state_dict(best_model["m0"])
m1.load_state_dict(best_model["m1"])
else:
print("Best model not found. Aborting test.")
return
m0.eval()
m1.eval()
# loss function
criterion = NLLcriterion(args.vocab_size)
lossF = lambda o, t: criterion(o, t)
# check device
if args.cuda and torch.cuda.is_available():
print("=> test with GPU")
m0.cuda()
m1.cuda()
criterion.cuda()
#m0 = nn.DataParallel(m0, dim=1)
else:
print("=> test with CPU")
num_iteration = (testData.size + args.batch - 1) // args.batch
total_genloss = 0
total_tokens = 0
with torch.no_grad():
for iter in range(num_iteration):
gendata = testData.getbatch_generative()
genloss = genLoss(gendata, m0, m1, lossF, args)
total_genloss += genloss.item()  # accumulate batch loss (no sequence-length multiplication here)
total_tokens += (gendata.trg != PAD).sum().item() # count non-pad tokens
print("Testing genloss at {} iteration is {}".format(iter, total_genloss))
avg_loss = total_genloss / total_tokens
perplexity = torch.exp(torch.tensor(avg_loss))
print(f"[Test] Avg Loss: {avg_loss:.4f} | Perplexity: {perplexity:.2f}")
ARGS
class Args:
data = 'data/'
checkpoint = 'data/checkpoint.pt'
vocab_size = len(h3_vocab)
embedding_size = 128
hidden_size = 128
num_layers = 1
dropout = 0.1
max_grad_norm = 1.0
learning_rate = 1e-2
lr_decay_patience = 20
early_stopping_patience = 50
cuda = torch.cuda.is_available()
bidirectional = True
batch = 16
num_epochs = 100
bucketsize = [(20,30),(30,30),(30,50),(50,50),(50,70),(70,70),(70,100),(100,100)]
criterion_name = "NLL"
use_discriminative = True
discriminative_w = 0.1
max_num_line = 200000
start_iteration = 0
generator_batch = 16
print_freq = 10
save_freq = 10
args = Args()
train(args)
Test
test(args)
Results
The generative model’s performance on the test dataset demonstrates remarkable accuracy in predicting vessel trajectories. Over the course of the evaluation, the cumulative generative loss across batches increased steadily, reflecting the accumulation of prediction errors over the sequences. When aggregated and normalized per token, the average loss was 0.2309, corresponding to a perplexity of approximately 1.26.
Perplexity is a standard measure in sequence modeling that quantifies how well a probabilistic model predicts a sequence of tokens. A perplexity close to 1 indicates near-deterministic prediction, meaning that the model assigns very high probability to the correct next token in the sequence. In the context of vessel trajectories, this result implies that the model is extremely confident and precise in forecasting the next H3 cell in a track, capturing the underlying spatial and temporal patterns in the data.
These results are particularly noteworthy because vessel movements are constrained by both geography and navigational behavior. The model effectively learns these patterns, predicting transitions between cells with minimal uncertainty. Achieving such a low perplexity confirms that the preprocessing pipeline, H3 cell encoding, and the sequence modeling architecture are all functioning harmoniously, enabling highly accurate trajectory modeling.
Overall, the evaluation demonstrates that the model not only generalizes well to unseen tracks but also reliably captures the deterministic structure of vessel movement, providing a robust foundation for downstream tasks such as trajectory prediction, anomaly detection, or maritime route analysis.
Testing genloss at 0 iteration is 46.40993881225586
Testing genloss at 1 iteration is 83.17555618286133
Testing genloss at 2 iteration is 122.76013565063477
Testing genloss at 3 iteration is 167.81907272338867
Testing genloss at 4 iteration is 223.75146102905273
Testing genloss at 5 iteration is 287.765926361084
Testing genloss at 6 iteration is 328.6252250671387
Testing genloss at 7 iteration is 394.95031356811523
Testing genloss at 8 iteration is 459.411678314209
Testing genloss at 9 iteration is 557.8198432922363
Testing genloss at 10 iteration is 724.5464973449707
Testing genloss at 11 iteration is 876.1395149230957
Testing genloss at 12 iteration is 1020.6461372375488
Testing genloss at 13 iteration is 1277.3499336242676
Testing genloss at 14 iteration is 1416.0101203918457
Testing genloss at 15 iteration is 1742.3399543762207
Testing genloss at 16 iteration is 2101.4984016418457
Testing genloss at 17 iteration is 2319.603458404541
[Test] Avg Loss: 0.2309 | Perplexity: 1.26
On this page, we will see how to use AISDb to discretize AIS tracks into hexagons.
Hexagonal Geospatial Indexing (H3): Uber's hierarchical hexagonal geospatial indexing system partitions the Earth into a multi-resolution hexagonal grid. Its key advantage over square grids is the "one-distance rule": all neighbors of a hexagon lie at comparable step distances.
As illustrated in the figure above, this uniformity removes the diagonal-versus-edge ambiguity present in square lattices. For maritime work, hexagons are great because they reduce directional bias and make neighborhood queries and aggregation intuitive.
Note: H3 indexes are 64-bit IDs typically shown as hex strings like “860e4d31fffffff.”
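The short sketch below illustrates the one-distance rule directly with the h3 package, assuming the same v4 h3-py API used elsewhere on this page (latlng_to_cell, grid_disk, grid_distance); the coordinates are arbitrary.
import h3
cell = h3.latlng_to_cell(50.0, -66.76, 6)   # index an arbitrary point at resolution 6
print(cell)                                  # a 64-bit index rendered as a hex string
for neighbor in h3.grid_disk(cell, 1):       # the cell itself plus its first ring of neighbors
    print(neighbor, h3.grid_distance(cell, neighbor))  # every true neighbor is exactly one step away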
The code below provides a complete example of how to connect to a database of AIS data using AISDb and generate the corresponding H3 index for each data point.
import aisdb
from aisdb import DBQuery
from aisdb.database.dbconn import PostgresDBConn
from datetime import datetime, timedelta
from aisdb.discretize.h3 import Discretizer # main import to convert lat/lon to H3 indexes
# >>> PostgreSQL connection details (replace placeholders or use environment variables) <<<
db_user = '<>' # PostgreSQL username
db_dbname = '<>' # PostgreSQL database/schema name
db_password = '<>' # PostgreSQL password
db_hostaddr = '127.0.0.1' # PostgreSQL host (localhost shown)
# Create a database connection handle for AISDB to use
dbconn = PostgresDBConn(
port=5555, # PostgreSQL port (5432 is default; 5555 here is just an example)
user=db_user, # username for authentication
dbname=db_dbname, # database/schema to connect to
host=db_hostaddr, # host address or DNS name
password=db_password, # password for authentication
)
# ------------------------------
# Define the spatial and temporal query window
# Note: bbox is [xmin, ymin, xmax, ymax] in lon/lat; variables below help readability
xmin, ymin, xmax, ymax = -70, 45, -58, 53
gulf_bbox = [xmin, ymin, xmax, ymax]  # optional helper; not used directly below
start_time = datetime(2023, 8, 1) # query start (inclusive)
end_time = datetime(2023, 8, 2) # query end (exclusive or inclusive per DB settings)
# Build a query that streams AIS rows in the time window and bounding box
qry = DBQuery(
dbconn=dbconn,
start=start_time, end=end_time,
xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax,
# Callback filters rows by time, bbox, and ensures MMSI validity (helps remove junk)
callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi
)
# Prepare containers/generators for streamed processing (memory-efficient)
ais_tracks = [] # placeholder list if you want to collect tracks (unused below)
rowgen = qry.gen_qry() # generator that yields raw AIS rows from the database
# ------------------------------
# Instantiate the H3 Discretizer at a chosen resolution
# Resolution 6 ≈ regional scale hexagons; increase for finer grids, decrease for coarser grids
discretizer = Discretizer(resolution=6)
# Build tracks from rows; decimate=True reduces oversampling and speeds up processing
tracks = aisdb.track_gen.TrackGen(rowgen, decimate=True)
# Optionally split long tracks into time-bounded segments (e.g., 4-week chunks)
# Useful for chunked processing or time-based aggregation; not used further in this snippet
tracks_segment = aisdb.track_gen.split_timedelta(
tracks,
timedelta(weeks=4)
)
# Discretize each track: adds an H3 index array aligned with lat/lon points for that track
# Each yielded track will have keys like 'lat', 'lon', and 'h3_index'
tracks_with_indexes = discretizer.yield_tracks_discretized_by_indexes(tracks)
# Example (optional) usage:
# for t in tracks_with_indexes:
# # Access the first point's H3 index for this track
# print(t['mmsi'], t['timestamp'][0], t['lat'][0], t['lon'][0], t['h3_index'][0])
# break
# Output: H3 Index for lat 50.003334045410156, lon -66.76000213623047: 860e4d31fffffff
Refer to the example notebook here: https://github.com/AISViz/AISdb/blob/master/examples/discretize.ipynb
Traditional Seq2Seq LSTM models have long been the workhorse for trajectory forecasting. They excel at learning temporal dependencies in AIS data, and with attention mechanisms they can capture complex nonlinear patterns over long histories. However, they remain purely statistical. This means the model can generate plausible-looking trajectories from a data perspective but with no guarantee that those predictions respect the underlying physics of vessel motion. In practice, this often manifests as sharp turns, unrealistic accelerations, or trajectories that deviate significantly when the model faces sparse or noisy data.
The NPINN-based approach directly addresses these shortcomings. By embedding smoothness and kinematic penalties into training, it enforces constraints on velocity and acceleration while still benefiting from the representational power of deep sequence models. Instead of simply fitting residuals between past and future positions, NPINN ensures that predictions evolve in ways consistent with how vessels actually move in the physical world. This leads to more reliable extrapolation, especially in data-scarce regions or unusual navigation scenarios.
The first step in building a trajectory learning pipeline is preprocessing AIS tracks into a model-friendly format. Raw AIS messages are noisy, irregularly sampled, and inconsistent across vessels, so we need to enforce structure before feeding them into neural networks. The function below does several things in sequence:
Data cleaning – removes spurious pings based on unrealistic speeds, encodes great-circle distances, and interpolates trajectories at fixed 5-minute intervals.
Track filtering – groups data by vessel (MMSI) and keeps only sufficiently long tracks to ensure stable training samples.
Feature extraction – converts lat/lon into projected coordinates (x, y), adds speed over ground (sog), and represents course over ground (cog) as sine/cosine to avoid angular discontinuities (a short numeric illustration follows this list).
Delta computation – calculates dx and dy between consecutive timestamps, capturing local motion dynamics.
Scaling – applies RobustScaler to normalize features and deltas while being resilient to outliers (common in AIS data).
The result is a clean, scaled DataFrame where each row represents a vessel state at a timestamp, enriched with both absolute position features and relative motion features.
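Before diving into the full function, here is a small numeric illustration (not part of the pipeline) of why course-over-ground is encoded as sine/cosine: the raw angle jumps from 359° to 1°, while the sin/cos pair changes smoothly across north.
import numpy as np
for cog_deg in (358.0, 359.0, 0.0, 1.0, 2.0):
    rad = np.radians(cog_deg)
    print(f"cog={cog_deg:5.1f}  sin={np.sin(rad):+.4f}  cos={np.cos(rad):+.4f}")
# The gap between 359 and 1 is 358 in raw degrees but only ~0.035 in (sin, cos) space,
# which matches the true 2-degree change of course.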
import pandas as pd
from sklearn.preprocessing import RobustScaler
def preprocess_aisdb_tracks(tracks_gen, proj,
sog_scaler=None,
feature_scaler=None,
delta_scaler=None,
fit_scaler=False):
# --- AISdb cleaning ---
tracks_gen = aisdb.remove_pings_wrt_speed(tracks_gen, 0.1)
tracks_gen = aisdb.encode_greatcircledistance(
tracks_gen,
distance_threshold=50000,
minscore=1e-5,
speed_threshold=50
)
tracks_gen = aisdb.interp_time(tracks_gen, step=timedelta(minutes=5))
# --- collect tracks ---
tracks = list(tracks_gen)
# --- group by MMSI ---
tracks_by_mmsi = defaultdict(list)
for track in tracks:
tracks_by_mmsi[track['mmsi']].append(track)
# --- keep only long-enough tracks ---
valid_tracks = []
for mmsi, mmsi_tracks in tracks_by_mmsi.items():
if all(len(t['time']) >= 100 for t in mmsi_tracks):
valid_tracks.extend(mmsi_tracks)
# --- flatten into dataframe ---
rows = []
for track in valid_tracks:
mmsi = track['mmsi']
sog = track.get('sog', [np.nan]*len(track['time']))
cog = track.get('cog', [np.nan]*len(track['time']))
for i in range(len(track['time'])):
x, y = proj(track['lon'][i], track['lat'][i])
cog_rad = np.radians(cog[i]) if cog[i] is not None else np.nan
rows.append({
'mmsi': mmsi,
'x': x,
'y': y,
'sog': sog[i],
'cog_sin': np.sin(cog_rad) if not np.isnan(cog_rad) else np.nan,
'cog_cos': np.cos(cog_rad) if not np.isnan(cog_rad) else np.nan,
'timestamp': pd.to_datetime(track['time'][i], errors='coerce')
})
df = pd.DataFrame(rows)
# --- clean NaNs ---
df = df.replace([np.inf, -np.inf], np.nan)
df = df.dropna(subset=['x', 'y', 'sog', 'cog_sin', 'cog_cos'])
# --- compute deltas per MMSI ---
df = df.sort_values(["mmsi", "timestamp"])
df["dx"] = df.groupby("mmsi")["x"].diff().fillna(0)
df["dy"] = df.groupby("mmsi")["y"].diff().fillna(0)
# --- scale features ---
feature_cols = ['x', 'y', 'sog', 'cog_sin', 'cog_cos']
delta_cols = ['dx', 'dy']
if fit_scaler:
# sog
sog_scaler = RobustScaler()
df['sog_scaled'] = sog_scaler.fit_transform(df[['sog']])
# absolute features
feature_scaler = RobustScaler()
df[feature_cols] = feature_scaler.fit_transform(df[feature_cols])
# deltas
delta_scaler = RobustScaler()
df[delta_cols] = delta_scaler.fit_transform(df[delta_cols])
else:
df['sog_scaled'] = sog_scaler.transform(df[['sog']])
df[feature_cols] = feature_scaler.transform(df[feature_cols])
df[delta_cols] = delta_scaler.transform(df[delta_cols])
return df, sog_scaler, feature_scaler, delta_scaler
We first query the AIS database for the training and testing periods and geographic bounds, producing generators of raw vessel tracks. These raw tracks are then preprocessed using preprocess_aisdb_tracks, which cleans the data, computes relative motion (dx, dy), scales features, and outputs a ready-to-use DataFrame. Training data fits new scalers, while test data is transformed using the same scalers to ensure consistency.
train_qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
start=START_DATE, end=END_DATE,
xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX)
train_gen = TrackGen(train_qry.gen_qry(verbose=True), decimate=False)
test_qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
start=TEST_START_DATE, end=TEST_END_DATE,
xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX)
test_gen = TrackGen(test_qry.gen_qry(verbose=True), decimate=False)
# --- Preprocess ---
train_df, sog_scaler, feature_scaler, delta_scaler = preprocess_aisdb_tracks(
train_gen, proj, fit_scaler=True
)
test_df, _, _, _ = preprocess_aisdb_tracks(
test_gen,
proj,
sog_scaler=sog_scaler,
feature_scaler=feature_scaler,
delta_scaler=delta_scaler,
fit_scaler=False
)
The create_sequences function transforms the preprocessed track data into supervised sequences suitable for model training. For each vessel, it slides a fixed-size window over the time series, building input sequences of past absolute features (x, y, dx, dy, cog_sin, cog_cos, sog_scaled) and target sequences of future residual movements (dx, dy). Using this, the dataset is split into training, validation, and test sets, with each set containing sequences ready for direct input into a trajectory prediction model.
def create_sequences(df, features, input_size=80, output_size=2, step=1):
"""
Build sequences:
X: past window of absolute features (x, y, dx, dy, cog_sin, cog_cos, sog_scaled)
Y: future residuals (dx, dy)
"""
X_list, Y_list = [], []
for mmsi in df['mmsi'].unique():
sub = df[df['mmsi'] == mmsi].sort_values('timestamp').copy()
# build numpy arrays
feat_arr = sub[features].to_numpy()
dxdy_arr = sub[['dx', 'dy']].to_numpy() # residuals already scaled
for i in range(0, len(sub) - input_size - output_size + 1, step):
# input sequence is absolute features
X_list.append(feat_arr[i : i + input_size])
# output sequence is residuals immediately after
Y_list.append(dxdy_arr[i + input_size : i + input_size + output_size])
return torch.tensor(X_list, dtype=torch.float32), torch.tensor(Y_list, dtype=torch.float32)
features = ['x', 'y', 'dx', 'dy', 'cog_sin', 'cog_cos', 'sog_scaled']
mmsis = train_df['mmsi'].unique()
train_mmsi, val_mmsi = train_test_split(mmsis, test_size=0.2, random_state=42, shuffle=True)
train_X, train_Y = create_sequences(train_df[train_df['mmsi'].isin(train_mmsi)], features)
val_X, val_Y = create_sequences(train_df[train_df['mmsi'].isin(val_mmsi)], features)
test_X, test_Y = create_sequences(test_df, features)
This block saves all processed data and supporting objects needed for training and evaluation. The preprocessed input and target sequences for training, validation, and testing are serialized using PyTorch (datasets_npin.pt). The fitted scalers for features, speed, and residuals are saved with joblib to ensure consistent scaling during inference. Finally, the projection parameters used to convert geographic coordinates to UTM are stored in JSON, allowing consistent coordinate transformations later.
import torch
import joblib
import json
# --- save datasets ---
torch.save({
'train_X': train_X, 'train_Y': train_Y,
'val_X': val_X, 'val_Y': val_Y,
'test_X': test_X, 'test_Y': test_Y
}, 'datasets_npin.pt')
# --- save scalers ---
joblib.dump(feature_scaler, "npinn_feature_scaler.pkl")
joblib.dump(sog_scaler, "npinn_sog_scaler.pkl")
joblib.dump(delta_scaler, "npinn_delta_scaler.pkl") # NEW
# --- save projection parameters ---
proj_params = {'proj': 'utm', 'zone': 20, 'ellps': 'WGS84'}
with open("npinn_proj_params.json", "w") as f:
json.dump(proj_params, f)
data = torch.load('datasets_npin.pt')
import torch
import joblib
import json
import pyproj
from torch.utils.data import TensorDataset, DataLoader
# scalers
feature_scaler = joblib.load("npinn_feature_scaler.pkl")
sog_scaler = joblib.load("npinn_sog_scaler.pkl")
delta_scaler = joblib.load("npinn_delta_scaler.pkl") # NEW
# projection
with open("npinn_proj_params.json", "r") as f:
proj_params = json.load(f)
proj = pyproj.Proj(**proj_params)
train_ds = TensorDataset(data['train_X'], data['train_Y'])
val_ds = TensorDataset(data['val_X'], data['val_Y'])
test_ds = TensorDataset(data['test_X'], data['test_Y'])
batch_size = 64
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=batch_size)
Seq2SeqLSTM is the sequence-to-sequence backbone used in the NPINN-based trajectory prediction framework. Within NPINN, the encoder LSTM processes past vessel observations (positions, speed, and course) to produce a hidden representation that captures motion dynamics. The decoder LSTMCell predicts future residuals in x and y, with an attention mechanism that selectively focuses on relevant past information at each step. Predicted residuals are added to the last observed position to reconstruct absolute trajectories.
This setup enables NPINN to generate smooth, physically consistent multi-step vessel trajectories, leveraging both historical motion patterns and learned dynamics constraints.
class Seq2SeqLSTM(nn.Module):
def __init__(self, input_size, hidden_size, input_steps, output_steps):
super().__init__()
self.input_steps = input_steps
self.output_steps = output_steps
# Encoder
self.encoder = nn.LSTM(input_size, hidden_size, num_layers=2, dropout=0.3, batch_first=True)
# Decoder
self.decoder = nn.LSTMCell(input_size, hidden_size)
self.attn = nn.Linear(hidden_size * 2, input_steps)
self.attn_combine = nn.Linear(hidden_size + input_size, input_size)
# Output only x,y residuals (added to last observed pos)
self.output_layer = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Linear(hidden_size // 2, 2)
)
def forward(self, x, target_seq=None, teacher_forcing_ratio=0.5):
batch_size = x.size(0)
encoder_outputs, (h, c) = self.encoder(x)
h, c = h[-1], c[-1]
last_obs = x[:, -1, :2] # last observed absolute x,y
decoder_input = x[:, -1, :] # full feature vector
outputs = []
for t in range(self.output_steps):
attn_weights = torch.softmax(self.attn(torch.cat((h, c), dim=1)), dim=1)
context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
dec_in = torch.cat((decoder_input, context), dim=1)
dec_in = self.attn_combine(dec_in)
h, c = self.decoder(dec_in, (h, c))
residual_xy = self.output_layer(h)
# accumulate into absolute xy
out_xy = residual_xy + last_obs
outputs.append(out_xy.unsqueeze(1))
# teacher forcing
if self.training and target_seq is not None and t < target_seq.size(1) and random.random() < teacher_forcing_ratio:
decoder_input = torch.cat([target_seq[:, t, :2], decoder_input[:, 2:]], dim=1)
last_obs = target_seq[:, t, :2]
else:
decoder_input = torch.cat([out_xy, decoder_input[:, 2:]], dim=1)
last_obs = out_xy
return torch.cat(outputs, dim=1)  # (batch, output_steps, 2)
This training loop implements NPINN-based trajectory learning, combining data fidelity with physics-inspired smoothness constraints. The weighted_coord_loss enforces accurate prediction of future x, y positions, while xy_npinn_smoothness_loss encourages smooth velocity and acceleration profiles, reflecting realistic vessel motion.
By integrating these two objectives, NPINN learns trajectories that are both close to observed data and physically plausible, with the smoothness weight gradually decaying during training to balance learning accuracy with dynamic consistency. Validation is performed each epoch to ensure generalization, and the best-performing model is saved. This approach differentiates NPINN from standard Seq2Seq training by explicitly incorporating motion dynamics into the loss, rather than relying purely on sequence prediction.
import torch.nn.functional as F
def weighted_coord_loss(pred, target, coord_weight=5.0, reduction='mean'):
return F.smooth_l1_loss(pred, target, reduction=reduction)
def xy_npinn_smoothness_loss(seq_full, coord_min=None, coord_max=None):
"""
NPINN-inspired smoothness penalty on xy coordinates
seq_full: [B, T, 2]
"""
xy = seq_full[..., :2]
if coord_min is not None and coord_max is not None:
xy_norm = (xy - coord_min) / (coord_max - coord_min + 1e-8)
xy_norm = 2 * (xy_norm - 0.5) # [-1,1]
else:
xy_norm = xy
v = xy_norm[:, 1:, :] - xy_norm[:, :-1, :]
a = v[:, 1:, :] - v[:, :-1, :]
return (v**2).mean() * 0.05 + (a**2).mean() * 0.5
def train_model(model, loader, val_dl, optimizer, device, epochs=50,
smooth_w_init=1e-3, coord_min=None, coord_max=None):
best_loss = float('inf')
best_state = None
for epoch in range(epochs):
model.train()
total_loss = total_data_loss = total_smooth_loss = 0.0
for batch_x, batch_y in loader:
batch_x = batch_x.to(device) # [B, T_in, F]
batch_y = batch_y.to(device) # [B, T_out, 2] absolute x,y
optimizer.zero_grad()
pred_xy = model(batch_x, target_seq=batch_y, teacher_forcing_ratio=0.5)
# Data loss: directly on absolute x,y
loss_data = weighted_coord_loss(pred_xy, batch_y)
# Smoothness loss: encourage smooth xy trajectories
y_start = batch_x[:, :, :2]
full_seq = torch.cat([y_start, pred_xy], dim=1) # observed + predicted
loss_smooth = xy_npinn_smoothness_loss(full_seq, coord_min, coord_max)
smooth_weight = smooth_w_init * max(0.1, 1.0 - epoch / 30.0)
loss = loss_data + smooth_weight * loss_smooth
loss.backward()
optimizer.step()
total_loss += loss.item()
total_data_loss += loss_data.item()
total_smooth_loss += loss_smooth.item()
avg_loss = total_loss / len(loader)
print(f"Epoch {epoch+1} | Total: {avg_loss:.6f} | Data: {total_data_loss/len(loader):.6f} | Smooth: {total_smooth_loss/len(loader):.6f}")
# Validation
model.eval()
val_loss = 0.0
with torch.no_grad():
for xb, yb in val_dl:
xb, yb = xb.to(device), yb.to(device)
pred_xy = model(xb, target_seq=yb, teacher_forcing_ratio=0.0)
data_loss = weighted_coord_loss(pred_xy, yb)
full_seq = torch.cat([xb[..., :2], pred_xy], dim=1)
loss_smooth = xy_npinn_smoothness_loss(full_seq, coord_min, coord_max)
val_loss += (data_loss + smooth_weight * loss_smooth).item()
val_loss /= len(val_dl)
print(f" Val Loss: {val_loss:.6f}")
if val_loss < best_loss:
best_loss = val_loss
best_state = model.state_dict()
if best_state is not None:
torch.save(best_state, "best_model_NPINN.pth")
print("Best model saved")This training sets a fixed random seed for reproducibility (torch, numpy, random) and enables deterministic CuDNN behavior. It initializes a Seq2SeqLSTM NPINN model to predict future vessel trajectory residuals from past sequences of absolute features. Global min/max of xy coordinates are computed from the training set for NPINN’s smoothness loss normalization. The model is trained on GPU if available using Adam, combining data loss on predicted xy positions with a physics-inspired smoothness penalty to enforce realistic, physically plausible trajectories.
import torch
import numpy as np
import random
from torch import nn
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_X = data['train_X']
input_steps = 80
output_steps = 2
# Collapse batch and time dims to compute global min/max for each feature
flat_train_X = train_X.view(-1, train_X.shape[-1]) # shape: [N*T, F]
x_min, x_max = flat_train_X[:, 0].min().item(), flat_train_X[:, 0].max().item()
y_min, y_max = flat_train_X[:, 1].min().item(), flat_train_X[:, 1].max().item()
cog_sin_min, cog_sin_max = flat_train_X[:, 2].min().item(), flat_train_X[:, 2].max().item()
cog_cos_min, cog_cos_max = flat_train_X[:, 3].min().item(), flat_train_X[:, 3].max().item()
sog_min, sog_max = flat_train_X[:, 4].min().item(), flat_train_X[:, 4].max().item()
coord_min = torch.tensor([x_min, y_min], device=device)
coord_max = torch.tensor([x_max, y_max], device=device)
# Model setup
input_size = 7
hidden_size = 64
model = Seq2SeqLSTM(input_size, hidden_size, input_steps, output_steps).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# coord_min/max only for xy
coord_min = torch.tensor([x_min, y_min], device=device)
coord_max = torch.tensor([x_max, y_max], device=device)
train_model(model, train_dl, val_dl, optimizer, device,
            coord_min=coord_min, coord_max=coord_max)
This snippet sets up the environment for inference or evaluation of the NPINN Seq2Seq model:
Chooses GPU if available.
Loads preprocessed datasets (train_X/Y, val_X/Y, test_X/Y).
Loads the saved RobustScalers for features, SOG, and deltas to match preprocessing during training.
Loads projection parameters to convert lon/lat to projected coordinates consistently.
Rebuilds the same Seq2SeqLSTM NPINN model used during training and loads the best saved weights.
Puts the model in evaluation mode, ready for predicting future vessel trajectories.
Essentially, this is the full recovery pipeline for NPINN inference, ensuring consistency with training preprocessing, scaling, and projection.
import torch
import joblib
import json
import pyproj
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import cartopy.crs as ccrs
import cartopy.feature as cfeature
# --- device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# recover arrays (assumes you already loaded `data`)
train_X, train_Y = data['train_X'], data['train_Y']
val_X, val_Y = data['val_X'], data['val_Y']
test_X, test_Y = data['test_X'], data['test_Y']
# --- load scalers (must match what you saved during preprocessing) ---
feature_scaler = joblib.load("npinn_feature_scaler.pkl")
sog_scaler = joblib.load("npinn_sog_scaler.pkl")
delta_scaler = joblib.load("npinn_delta_scaler.pkl") # for dx,dy
# if you named them differently, change the filenames above
# --- load projection params and build proj ---
with open("proj_params.json", "r") as f:
proj_params = json.load(f)
proj = pyproj.Proj(**proj_params)
# --- rebuild & load model ---
input_size = train_X.shape[2]
input_steps = train_X.shape[1]
output_steps = train_Y.shape[1]
hidden_size = 64
num_layers = 2
best_model = Seq2SeqLSTM(
input_size=input_size,
hidden_size=hidden_size,
input_steps=input_steps,
output_steps=output_steps,
).to(device)
best_model.load_state_dict(torch.load("best_model_NPINN_absXY.pth", map_location=device))
best_model.eval()
This code provides essential postprocessing and geometric utilities for working with NPINN outputs and AIS trajectory data. The inverse_dxdy_np function is designed to convert scaled residuals (dx, dy) back into real-world units (meters) using a previously fitted scaler. It handles both 1D and 2D inputs, making it suitable for batch or single-step predictions. This is particularly useful for interpreting NPINN predictions in absolute physical units rather than in the normalized or scaled feature space, allowing for meaningful evaluation of the model’s accuracy in real-world terms. Using this, the code also computes the standard deviation of the residuals across the training dataset, providing a quantitative measure of typical displacement magnitudes along the x and y axes.
The code also includes geometry-related helpers to analyze trajectories in geospatial terms. The haversine function calculates the geodesic distance between longitude/latitude points in meters using the haversine formula, with safeguards for numerical stability and invalid inputs. Building on this, the trajectory_length function computes the total length of a vessel’s trajectory, summing distances between consecutive points while handling incomplete or non-finite data gracefully. Together, these utilities allow NPINN outputs to be mapped back to real-world coordinates, facilitate evaluation of trajectory smoothness and accuracy, and provide interpretable metrics for model validation and downstream analysis.
# ---------------- helper inverse/scaling utilities ----------------
def inverse_dxdy_np(dxdy_scaled, scaler):
"""
Invert scaled residuals (dx, dy) back to meters.
dxdy_scaled: (..., 2) scaled
scaler: RobustScaler/StandardScaler/MinMaxScaler fitted on residuals
"""
dxdy_scaled = np.asarray(dxdy_scaled, dtype=float)
if dxdy_scaled.ndim == 1:
dxdy_scaled = dxdy_scaled[None, :]
n_samples = dxdy_scaled.shape[0]
n_features = scaler.scale_.shape[0] if hasattr(scaler, "scale_") else scaler.center_.shape[0]
full_scaled = np.zeros((n_samples, n_features))
full_scaled[:, :2] = dxdy_scaled
if hasattr(scaler, "mean_"):
center = scaler.mean_
scale = scaler.scale_
elif hasattr(scaler, "center_"):
center = scaler.center_
scale = scaler.scale_
else:
raise ValueError(f"Scaler type {type(scaler)} not supported")
full = full_scaled * scale + center
return full[:, :2] if dxdy_scaled.shape[0] > 1 else full[0, :2]
# ---------------- compute residual std (meters) correctly ----------------
# train_Y contains scaled residuals (dx,dy) per your preprocessing.
all_resids_scaled = train_Y.reshape(-1, 2) # [sum_T, 2]
all_resids_m = inverse_dxdy_np(all_resids_scaled, delta_scaler) # meters
residual_std = np.std(all_resids_m, axis=0)
print("Computed residual_std (meters):", residual_std)
# ---------------- geometry helpers ----------------
def haversine(lon1, lat1, lon2, lat2):
"""Distance (m) between lon/lat points using haversine formula; handles arrays."""
R = 6371000.0
lon1 = np.asarray(lon1, dtype=float)
lat1 = np.asarray(lat1, dtype=float)
lon2 = np.asarray(lon2, dtype=float)
lat2 = np.asarray(lat2, dtype=float)
# if any entry is non-finite, result will be nan — we'll guard upstream
lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])
dlon = lon2 - lon1
dlat = lat2 - lat1
a = np.sin(dlat/2.0)**2 + np.cos(lat1)*np.cos(lat2)*np.sin(dlon/2.0)**2
# numerical stability: clip inside sqrt
a = np.clip(a, 0.0, 1.0)
return 2 * R * np.arcsin(np.sqrt(a))
def trajectory_length(lons, lats):
lons = np.asarray(lons, dtype=float)
lats = np.asarray(lats, dtype=float)
if lons.size < 2:
return 0.0
# guard non-finite
if not (np.isfinite(lons).all() and np.isfinite(lats).all()):
return float("nan")
    return np.sum(haversine(lons[:-1], lats[:-1], lons[1:], lats[1:]))
This function evaluate_with_errors is designed to evaluate NPINN trajectory predictions in a geospatial context and optionally visualize them. It takes a trained model, a test DataLoader, coordinate projection, scalers, and device information. For each batch, it reconstructs the predicted trajectories from residuals (dx, dy), inverts the scaling back to meters, and converts them to absolute positions starting from the last observed point. Different decoding modes (cumsum, independent, stdonly) allow flexibility in how residuals are integrated into absolute trajectories, and it handles cases where the first residual is effectively a duplicate of the last input.
The evaluation computes per-timestep errors in meters using the haversine formula and tracks differences in trajectory lengths. All errors are summarized with mean and median statistics across the prediction horizon. When plot_map=True, the function generates separate maps for each trajectory, overlaying the true (green) and predicted (red dashed) paths, giving a clear visual inspection of the model’s performance. This approach is directly aligned with NPINN, as it evaluates predictions in physical units and emphasizes smooth, physically plausible trajectory reconstructions.
def evaluate_with_errors(
model,
test_dl,
proj,
feature_scaler,
delta_scaler,
device,
num_batches=None, # None = use full dataset
dup_tol: float = 1e-4,
outputs_are_residual_xy: bool = True,
residual_decode_mode: str = "cumsum", # "cumsum", "independent", "stdonly"
residual_std: np.ndarray = None,
plot_map: bool = True # <--- PLOT ALL TRAJECTORIES
):
"""
Evaluate model trajectory predictions and report errors in meters.
Optionally plot all trajectories on a map.
"""
model.eval()
errors_all = []
length_diffs = []
bad_count = 0
# store all trajectories
all_real = []
all_pred = []
with torch.no_grad():
batches = 0
for xb, yb in test_dl:
xb, yb = xb.to(device), yb.to(device)
pred = model(xb, teacher_forcing_ratio=0.0) # [B, T_out, F]
# first sample of the batch
input_seq = xb[0].cpu().numpy()
real_seq = yb[0].cpu().numpy()
pred_seq = pred[0].cpu().numpy()
# Extract dx, dy residuals
pred_resid_s = pred_seq[:, :2]
real_resid_s = real_seq[:, :2]
# Invert residuals to meters
pred_resid_m = inverse_dxdy_np(pred_resid_s, delta_scaler)
real_resid_m = inverse_dxdy_np(real_resid_s, delta_scaler)
# Use last observed absolute position as starting point (meters)
last_obs_xy_m = inverse_xy_only_np(input_seq[-1, :2], feature_scaler)
# Reconstruct absolute positions
if residual_decode_mode == "cumsum":
pred_xy_m = np.cumsum(pred_resid_m, axis=0) + last_obs_xy_m
real_xy_m = np.cumsum(real_resid_m, axis=0) + last_obs_xy_m
elif residual_decode_mode == "independent":
pred_xy_m = pred_resid_m + last_obs_xy_m
real_xy_m = real_resid_m + last_obs_xy_m
elif residual_decode_mode == "stdonly":
if residual_std is None:
raise ValueError("residual_std must be provided for 'stdonly' mode")
noise = np.random.randn(*pred_resid_m.shape) * residual_std
pred_xy_m = np.cumsum(noise, axis=0) + last_obs_xy_m
real_xy_m = np.cumsum(real_resid_m, axis=0) + last_obs_xy_m
else:
raise ValueError(f"Unknown residual_decode_mode: {residual_decode_mode}")
# Remove first target if duplicates last input
if np.allclose(real_resid_m[0], 0, atol=dup_tol):
real_xy_m = real_xy_m[1:]
pred_xy_m = pred_xy_m[1:]
# align horizon
min_len = min(len(pred_xy_m), len(real_xy_m))
if min_len == 0:
bad_count += 1
continue
pred_xy_m = pred_xy_m[:min_len]
real_xy_m = real_xy_m[:min_len]
# project to lon/lat
lon_real, lat_real = proj(real_xy_m[:,0], real_xy_m[:,1], inverse=True)
lon_pred, lat_pred = proj(pred_xy_m[:,0], pred_xy_m[:,1], inverse=True)
all_real.append((lon_real, lat_real))
all_pred.append((lon_pred, lat_pred))
# compute per-timestep errors
errors = haversine(lon_real, lat_real, lon_pred, lat_pred)
errors_all.append(errors)
# trajectory length diff
real_len = trajectory_length(lon_real, lat_real)
pred_len = trajectory_length(lon_pred, lat_pred)
length_diffs.append(abs(real_len - pred_len))
print(f"Trajectory length (true): {real_len:.2f} m | pred: {pred_len:.2f} m | diff: {abs(real_len - pred_len):.2f} m")
batches += 1
if num_batches is not None and batches >= num_batches:
break
# summary
if len(errors_all) == 0:
print("No valid samples evaluated. Bad count:", bad_count)
return
max_len = max(len(e) for e in errors_all)
errors_padded = np.full((len(errors_all), max_len), np.nan)
for i, e in enumerate(errors_all):
errors_padded[i, :len(e)] = e
mean_per_t = np.nanmean(errors_padded, axis=0)
print("\n=== Summary (meters) ===")
for t, v in enumerate(mean_per_t):
if not np.isnan(v):
print(f"t={t} mean error: {v:.2f} m")
print(f"Mean over horizon: {np.nanmean(errors_padded):.2f} m | Median: {np.nanmedian(errors_padded):.2f} m")
print(f"Mean trajectory length diff: {np.mean(length_diffs):.2f} m | Median: {np.median(length_diffs):.2f} m")
print("Bad / skipped samples:", bad_count)
# --- plot all trajectories ---
# --- plot each trajectory separately ---
if plot_map and len(all_real) > 0:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
for idx, ((lon_r, lat_r), (lon_p, lat_p)) in enumerate(zip(all_real, all_pred)):
fig = plt.figure(figsize=(10, 8))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.add_feature(cfeature.LAND)
ax.add_feature(cfeature.COASTLINE)
ax.add_feature(cfeature.BORDERS, linestyle=':')
ax.plot(lon_r, lat_r, color='green', linewidth=2, label="True")
ax.plot(lon_p, lat_p, color='red', linestyle='--', linewidth=2, label="Predicted")
ax.legend()
ax.set_title(f"Trajectory {idx+1}: True vs Predicted")
plt.show()
Trajectory 2
True length: 521.39 m
Predicted length: 508.66 m
Difference: 12.73 m
The predicted path is roughly parallel to the true path but slightly offset. The predicted trajectory underestimates the total path length by ~2.4%, which is a small but noticeable error. The smoothness of the red dashed line indicates that the model is generating physically plausible, consistent trajectories.
Trajectory 5
True length: 188.76 m
Predicted length: 206.01 m
Difference: 17.25 m
Here, the predicted trajectory slightly overestimates the total distance (~9%), again with a smooth but slightly offset path relative to the true trajectory. The model captures the overall direction but has some scaling error in step lengths or residuals.
Trajectory length (true): 521.39 m | pred: 508.66 m | diff: 12.73 m
Trajectory length (true): 188.76 m | pred: 206.01 m | diff: 17.25 m
t=0 mean error: 45.03 m. The first prediction step already has a ~45 m average discrepancy from the true position, which is common since the model starts accumulating error immediately after the last observed point.
t=1 mean error: 80.80 m. Error grows with the horizon, reflecting the cumulative effect of residual inaccuracies.
Mean over horizon: 62.92 m | Median: 61.72 m. On average, predictions are within ~60–63 m of the true trajectory at any given time step. The median being close to the mean suggests a fairly symmetric error distribution without extreme outliers.
Mean trajectory length difference: 11.82 m | Median: 12.73 m. Overall, the predicted trajectories' total lengths are very close to the true lengths, typically within ~12 m, which is less than 3% relative error for most trajectories.
The model captures trajectory trends well but shows small offsets in absolute positions.
Errors grow with horizon, which is typical for sequence prediction models using residuals.
Smoothness is maintained (no erratic jumps), indicating that the NPINN smoothness regularization is effective.
Overall, this is a solid performance for maritime AIS trajectory prediction, especially given the scale of trajectories (hundreds of meters).
Sequence to Sequence using Torch
Vessel trajectories are a type of geospatial temporal data derived from AIS (Automatic Identification System) signals. In this tutorial, we will go over the most common Machine Learning Library to process and model AIS trajectory data.
We will begin with PyTorch, a widely used deep learning library designed for building and training neural networks. Specifically, we will implement a recurrent neural network using LSTM (Long Short-Term Memory) to model sequential patterns in vessel movements.
We will utilize AISdb, a dedicated framework for querying, filtering, and preprocessing vessel trajectory data, to streamline data preparation for machine learning workflows.
First, let's import the libraries we'll be using throughout this tutorial. Our main tools will be NumPy and PyTorch, along with a few other standard libraries for data handling, model building, and visualization.
pandas, numpy: for handling tables and arrays
torch: for building and training deep learning models
sklearn: for data splitting and evaluation utilities
matplotlib: for visualizing model performance and outputs
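A minimal import block consistent with this list might look as follows; the complete imports appear in the code listing at the end of this tutorial:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split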
Assuming you have the database ready, you can replace the file path and establish a connection.
We have processed a database containing open-source AIS data from Marine Cadastre, covering January to March near Maine, United States.
To generate the query using AISdb, we use the DBQuery function. All you have to change here is DB_CONNECTION, START_DATE, END_DATE, and the bounding coordinates.
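For reference, a condensed version of the connection and query setup, taken from the full code listing at the end of this tutorial (the database path and dates are placeholders you should replace):
from datetime import datetime
import aisdb
from aisdb import DBConn
from aisdb.database.sqlfcn_callbacks import in_timerange
from aisdb.track_gen import TrackGen

DB_CONNECTION = "/home/sqlite_database_file.db"  # replace with your data path
START_DATE = datetime(2018, 8, 1, hour=0)
END_DATE = datetime(2018, 8, 4, hour=2)
XMIN, YMIN, XMAX, YMAX = -64.828126, 46.113933, -58.500001, 49.619290

dbconn = DBConn(dbpath=DB_CONNECTION)
qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
                    start=START_DATE, end=END_DATE,
                    xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX)
tracks = TrackGen(qry.gen_qry(verbose=True), decimate=False)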
Sample coordinates look like this on the map:
We use pyproj for the metric projection of the latitude and longitude values. You can learn more in the pyproj documentation.
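A minimal sketch of the projection step (the UTM zone matches the sample region used later; the test point is arbitrary):
import pyproj

proj = pyproj.Proj(proj='utm', zone=20, ellps='WGS84')
x, y = proj(-61.7, 47.2)              # lon/lat -> meters
lon, lat = proj(x, y, inverse=True)   # meters -> lon/lat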
We follow these steps to preprocess the queried trajectory data:
Remove pings with respect to speed
Encode tracks given a distance threshold
Interpolate according to time (5 minutes here)
Group data by MMSI
Filter out MMSIs with fewer than 100 points
Convert latitude/longitude to x and y on a Cartesian plane using pyproj
Use the sine and cosine of COG, since it is a 360-degree value
Drop NaN values
Apply scaling to ensure values are normalized
The steps above are wrapped into the preprocess_aisdb_tracks function, shown in full in the code listing at the end of this tutorial.
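Its core AISdb cleaning chain, condensed here for orientation (the thresholds match the full definition; the shortened function name is only illustrative):
from datetime import timedelta
import aisdb

def clean_tracks(tracks_gen):
    # AISdb cleaning steps applied lazily on the track generator
    tracks_gen = aisdb.remove_pings_wrt_speed(tracks_gen, 0.1)
    tracks_gen = aisdb.encode_greatcircledistance(tracks_gen,
                                                  distance_threshold=50000,
                                                  minscore=1e-5,
                                                  speed_threshold=50)
    tracks_gen = aisdb.interp_time(tracks_gen, step=timedelta(minutes=5))
    return list(tracks_gen)  # grouping, filtering, projection, and scaling follow in the full function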
Next, we process all vessel tracks and split them into training and test sets, which are used for model training and evaluation.
For geospatial-temporal data, we typically use a sliding window approach, where each trajectory is segmented into input sequences of length X to predict the next Y steps. In this tutorial, we set X = 80 and Y = 2.
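A sliding-window helper along these lines does the segmentation; the create_sequences function in the full listing applies the same idea per MMSI, and the name make_windows here is illustrative:
import numpy as np
import torch

def make_windows(seq, input_size=80, output_size=2, step=1):
    """seq: [T, F] feature array for a single vessel, sorted by time."""
    X, Y = [], []
    for i in range(0, len(seq) - input_size - output_size + 1, step):
        X.append(seq[i:i + input_size])
        Y.append(seq[i + input_size:i + input_size + output_size])
    return (torch.tensor(np.stack(X), dtype=torch.float32),
            torch.tensor(np.stack(Y), dtype=torch.float32))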
We then save all this data as well as the scalers (we'll use this towards the end in evaluation)
Now we can load the data and start experimenting with it. The same data can also be reused across different models we want to explore.
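Reloading the saved tensors and wrapping them in DataLoaders mirrors the full listing at the end of this tutorial (batch size 64 is the value used there):
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.load('datasets_seq2seq_cartesian5.pt')
train_dl = DataLoader(TensorDataset(data['train_X'], data['train_Y']), batch_size=64, shuffle=True)
val_dl = DataLoader(TensorDataset(data['val_X'], data['val_Y']), batch_size=64, shuffle=False)
test_dl = DataLoader(TensorDataset(data['test_X'], data['test_Y']), batch_size=64)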
We use an attention-based encoder–decoder LSTM model for trajectory prediction. The model has two layers and incorporates teacher forcing, a strategy where the decoder is occasionally fed the ground-truth values during training. This helps stabilize learning and prevents the model from drifting too far when making multi-step predictions.
Two auxiliary functions are introduced to augment the original MSE loss. These additional terms are designed to better preserve the physical consistency and structural shape of the predicted trajectory.
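To illustrate the smoothness term on its own, here is a minimal standalone version (the 0.05 and 0.5 weights match the xy_smoothness_loss used later); a constant-velocity trajectory incurs no acceleration penalty:
import torch

def xy_smoothness(seq_xy):
    """seq_xy: [B, T, 2] absolute x, y positions."""
    v = seq_xy[:, 1:, :] - seq_xy[:, :-1, :]   # first differences ~ velocity
    a = v[:, 1:, :] - v[:, :-1, :]             # second differences ~ acceleration
    return (v ** 2).mean() * 0.05 + (a ** 2).mean() * 0.5

# straight-line motion: the acceleration term vanishes
straight = torch.arange(10, dtype=torch.float32).view(1, 10, 1).repeat(1, 1, 2)
print(xy_smoothness(straight))  # tensor(0.0500)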
Once the model is defined, the next step is to train it on our prepared dataset. Training involves iteratively feeding input sequences to the model, comparing its predictions against the ground truth, and updating the weights to reduce the error.
In our case, the loss function combines:
a data term (based on weighted coordinate errors and auxiliary features), and
a smoothness penalty (to encourage realistic vessel movement and reduce jitter in the predicted trajectory).
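The weight applied to the smoothness penalty decays over training; a short sketch matching the schedule used in the training code:
def smooth_weight(epoch, smooth_w_init=1e-3):
    # linear decay over roughly the first 30 epochs, floored at 10% of the initial value
    return smooth_w_init * max(0.1, 1.0 - epoch / 30.0)

for epoch in (0, 15, 30, 45):
    print(epoch, smooth_weight(epoch))  # 1e-3, 5e-4, 1e-4, 1e-4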
Finally, now that our model has been trained, we use an evaluation function to check it on the held-out dataset we stored earlier, and plot the predictions on a map to see how the trajectories look. Note: we don't rely only on the accuracy or the training/testing numbers. The loss can look small in scaled units while the predicted coordinates are far off, which is why we also plot the predictions on a map.
There are also some debugging statements to verify whether the scaling is correct, report the distance error, and so on. With this model we obtain a metric distance error of only about 800 m.
Predicted vs True (lat/lon)
Summary (meters): t=0 mean error: 833.31 m | mean over horizon: 833.31 m | median: 833.31 m
One of the most potent applications enabled by LLMs is the development of question-answering chatbots. These are applications that can answer questions about specific source information. This tutorial demonstrates how to build a chatbot that can answer questions about AISViz documentation using a technique known as Retrieval-Augmented Generation (RAG).
A typical RAG application has two main components:
Indexing: a pipeline for scraping data from documentation and indexing it.
This usually happens offline.
Retrieval and generation: the actual RAG chain, which takes the user query at runtime and retrieves the relevant data from the index, then passes it to the model.
The most common complete sequence from raw docs to answer looks like:
Scrape: First, we need to scrape all documentation pages. This includes the GitBook documentation and related pages.
Split: Text splitters break large documents into smaller chunks. This is useful both for indexing data and passing it into a model, since large chunks are harder to search over and won't fit in a model's finite context window.
Store: We need somewhere to store and index our splits, so that they can be searched over later. This is done using the Chroma vector database and embeddings.
Retrieve: Given a user input, relevant splits are retrieved from Chroma using similarity search.
Generate: An LLM produces an answer in response to a system prompt that combines both the question and the retrieved context.
This tutorial requires these dependencies:
You'll need a Google Gemini API key (or a key for another LLM provider). Set it as an environment variable:
We need to select three main components:
LLM: We'll use Google's Gemini models through LangChain
Embeddings: Hugging Face SentenceTransformers for creating document embeddings
Vector Store: Chroma for storing and searching document embeddings
We can create a simple indexing pipeline and RAG chain to do this in about 100 lines of code.
Scraping Documentation
We need to first scrape all the AISViz documentation pages.
Splitting documents
Our scraped documents can be quite long, so we need to split them into smaller chunks. We'll use a simple text splitter that breaks documents into chunks of specified size with some overlap.
Storing documents with SentenceTransformers
Now we need to create embeddings for our chunks using Hugging Face SentenceTransformers and store them in the Chroma vector database.
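A compact sketch of this storing step, using toy strings in place of real scraped chunks (the full scraper, splitter, and store appear in the code section at the end of this tutorial):
from sentence_transformers import SentenceTransformer
import chromadb

# Toy "chunks" standing in for scraped and split documentation pages
chunks = ["AISdb stores AIS messages in monthly SQLite tables.",
          "TrackGen converts query rows into per-vessel tracks."]

embedder = SentenceTransformer('all-MiniLM-L6-v2')
client = chromadb.PersistentClient(path="./chroma_db")
collection = client.get_or_create_collection(name="aisviz_docs")
collection.add(
    embeddings=embedder.encode(chunks).tolist(),
    documents=chunks,
    ids=[f"chunk_{i}" for i in range(len(chunks))],
)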
This completes the Indexing portion of the pipeline. At this point, we have a queryable vector store containing the chunked contents of all the documentation with embeddings created by SentenceTransformers. Given a user question, we should be able to return the most relevant snippets.
Now let's write the actual application logic. We aim to create a simple function that takes a user question, searches for relevant documents using SentenceTransformers embeddings, and generates an answer using Google Gemini. A high-level breakdown:
API Key Setup – Load LLM (Gemini in this example) API key from environment or prompt user.
Model Initialization – Wrap Google Gemini (gemini-2.5-flash) with LangChain.
Embedding – Convert the user's question into a vector with SentenceTransformers.
Retrieval – Query Chroma to fetch the top-k most relevant document chunks.
Context Building – Assemble retrieved docs + metadata into a context string.
Prompting LLM – Combine system + user prompts and send to LLM.
Answer Generation – Return a concise response along with sources and context.
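For example, the embedding and retrieval steps above can be sketched on their own, reusing the Chroma collection and SentenceTransformer model built during indexing:
from sentence_transformers import SentenceTransformer
import chromadb

embeddings_model = SentenceTransformer('all-MiniLM-L6-v2')
collection = chromadb.PersistentClient(path="./chroma_db").get_or_create_collection(name="aisviz_docs")

def retrieve(question, k=4):
    """Return the k most similar documentation chunks and their metadata."""
    q_emb = embeddings_model.encode([question])
    results = collection.query(query_embeddings=q_emb.tolist(), n_results=k)
    return results['documents'][0], results['metadatas'][0]

docs, metas = retrieve("How do I create a database with AISdb?")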
Now let's create a simple web interface using Gradio so others can interact with our chatbot:
The full code can be found here:
import io
import json
import random
from collections import defaultdict
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd  # used below when building the preprocessed DataFrame
import pyproj
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import joblib
import aisdb
from aisdb import DBConn, database
from aisdb.database import sqlfcn_callbacks
from aisdb.database.sqlfcn_callbacks import in_timerange
from aisdb.database.dbconn import PostgresDBConn
from aisdb.track_gen import TrackGen
import cartopy.crs as ccrs
import cartopy.feature as cfeature
DB_CONNECTION = "/home/sqlite_database_file.db"  # replace with your data path
START_DATE = datetime(2018, 8, 1, hour=0) # starting at 12 midnight on 1st August 2018
END_DATE = datetime(2018, 8, 4, hour=2)  # Ending at 2:00 am on 4th August 2018
XMIN, YMIN, XMAX, YMAX = -64.828126, 46.113933, -58.500001, 49.619290  # Sample coordinates - x refers to longitude and y to latitude
# database connection
dbconn = DBConn(dbpath=DB_CONNECTION)
# Generating query to extract data between a given time range
qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
start = START_DATE,
end=END_DATE,
xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX,
)
rowgen = qry.gen_qry(verbose=True) # generating query
tracks = TrackGen(rowgen, decimate=False) # Convert rows into tracks
# rowgen_ = qry.gen_qry(reaggregate_static=True, verbose=True) if you want metadata
# To avoid an overfitted model, let's choose data from a completely different date range for testing
TEST_START_DATE = datetime(2018, 8, 5, hour=0)
TEST_END_DATE = datetime(2018,8,6, hour= 8)
test_qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
start = TEST_START_DATE,
end = TEST_END_DATE,
xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX
)
test_tracks = TrackGen(test_qry.gen_qry(verbose=True), decimate=False)
# --- Projection: Lat/Lon -> Cartesian (meters) ---
proj = pyproj.Proj(proj='utm', zone=20, ellps='WGS84')
def preprocess_aisdb_tracks(tracks_gen, proj, sog_scaler=None, feature_scaler=None, fit_scaler=False):
# Keep as generator for AISdb functions
tracks_gen = aisdb.remove_pings_wrt_speed(tracks_gen, 0.1)
tracks_gen = aisdb.encode_greatcircledistance(tracks_gen,
distance_threshold=50000,
minscore=1e-5,
speed_threshold=50)
tracks_gen = aisdb.interp_time(tracks_gen, step=timedelta(minutes=5))
# Convert generator to list AFTER all AISdb steps
tracks = list(tracks_gen)
# Group by MMSI
tracks_by_mmsi = defaultdict(list)
for track in tracks:
tracks_by_mmsi[track['mmsi']].append(track)
# Keep only MMSIs with tracks >= 100 points
valid_tracks = []
for mmsi, mmsi_tracks in tracks_by_mmsi.items():
if all(len(t['time']) >= 100 for t in mmsi_tracks):
valid_tracks.extend(mmsi_tracks)
# Convert to DataFrame
rows = []
for track in valid_tracks:
mmsi = track['mmsi']
sog = track.get('sog', [np.nan]*len(track['time']))
cog = track.get('cog', [np.nan]*len(track['time']))
for i in range(len(track['time'])):
x, y = proj(track['lon'][i], track['lat'][i])
cog_rad = np.radians(cog[i]) if cog[i] is not None else np.nan
rows.append({
'mmsi': mmsi,
'x': x,
'y': y,
'sog': sog[i],
'cog_sin': np.sin(cog_rad) if not np.isnan(cog_rad) else np.nan,
'cog_cos': np.cos(cog_rad) if not np.isnan(cog_rad) else np.nan,
'timestamp': pd.to_datetime(track['time'][i], errors='coerce')
})
df = pd.DataFrame(rows)
# Drop rows with NaNs
df = df.replace([np.inf, -np.inf], np.nan)
df = df.dropna(subset=['x', 'y', 'sog', 'cog_sin', 'cog_cos'])
# Scale features
    feature_cols = ['x', 'y', 'sog', 'cog_sin', 'cog_cos']  # feature columns to scale
if fit_scaler:
sog_scaler = RobustScaler()
df['sog_scaled'] = sog_scaler.fit_transform(df[['sog']])
feature_scaler = RobustScaler()
df[feature_cols] = feature_scaler.fit_transform(df[feature_cols])
else:
df['sog_scaled'] = sog_scaler.transform(df[['sog']])
df[feature_cols] = feature_scaler.transform(df[feature_cols])
    return df, sog_scaler, feature_scaler
train_qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
start=START_DATE, end=END_DATE,
xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX)
train_gen = TrackGen(train_qry.gen_qry(verbose=True), decimate=False)
test_qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
start=TEST_START_DATE, end=TEST_END_DATE,
xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX)
test_gen = TrackGen(test_qry.gen_qry(verbose=True), decimate=False)
# --- Preprocess ---
train_df, sog_scaler, feature_scaler = preprocess_aisdb_tracks(train_gen, proj, fit_scaler=True)
test_df, _, _ = preprocess_aisdb_tracks(test_gen, proj, sog_scaler=sog_scaler,
                                    feature_scaler=feature_scaler, fit_scaler=False)
def create_sequences(df, features, input_size=80, output_size=2, step=1):
X_list, Y_list = [], []
for mmsi in df['mmsi'].unique():
sub = df[df['mmsi']==mmsi].sort_values('timestamp')[features].to_numpy()
for i in range(0, len(sub)-input_size-output_size+1, step):
X_list.append(sub[i:i+input_size])
Y_list.append(sub[i+input_size:i+input_size+output_size])
return torch.tensor(X_list, dtype=torch.float32), torch.tensor(Y_list, dtype=torch.float32)
features = ['x','y','cog_sin','cog_cos','sog_scaled']
mmsis = train_df['mmsi'].unique()
train_mmsi, val_mmsi = train_test_split(mmsis, test_size=0.2, random_state=42, shuffle=True)
train_X, train_Y = create_sequences(train_df[train_df['mmsi'].isin(train_mmsi)], features)
val_X, val_Y = create_sequences(train_df[train_df['mmsi'].isin(val_mmsi)], features)
test_X, test_Y = create_sequences(test_df, features)
# --- save datasets ---
torch.save({
'train_X': train_X, 'train_Y': train_Y,
'val_X': val_X, 'val_Y': val_Y,
'test_X': test_X, 'test_Y': test_Y
}, 'datasets_seq2seq_cartesian5.pt')
# --- save scalers ---
joblib.dump(feature_scaler, "feature_scaler.pkl")
joblib.dump(sog_scaler, "sog_scaler.pkl")
# --- save projection parameters ---
proj_params = {'proj': 'utm', 'zone': 20, 'ellps': 'WGS84'}
with open("proj_params.json", "w") as f:
json.dump(proj_params, f)
data = torch.load('datasets_seq2seq_cartesian5.pt')
# scalers
feature_scaler = joblib.load("feature_scaler.pkl")
sog_scaler = joblib.load("sog_scaler.pkl")
# projection
with open("proj_params.json", "r") as f:
proj_params = json.load(f)
proj = pyproj.Proj(**proj_params)
train_ds = TensorDataset(data['train_X'], data['train_Y'])
val_ds = TensorDataset(data['val_X'], data['val_Y'])
test_ds = TensorDataset(data['test_X'], data['test_Y'])
batch_size = 64
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=batch_size)
class Seq2SeqLSTM(nn.Module):
def __init__(self, input_size, hidden_size, input_steps, output_steps):
super().__init__()
self.input_steps = input_steps
self.output_steps = output_steps
# Encoder
self.encoder = nn.LSTM(input_size, hidden_size, num_layers=2, dropout=0.3, batch_first=True)
# Decoder
self.decoder = nn.LSTMCell(input_size, hidden_size)
self.attn = nn.Linear(hidden_size + hidden_size, input_steps) # basic Bahdanau-ish
self.attn_combine = nn.Linear(hidden_size + input_size, input_size)
# Output projection (predict residual proposal for all features)
self.output_layer = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Linear(hidden_size // 2, input_size)
)
def forward(self, x, target_seq=None, teacher_forcing_ratio=0.5):
"""
x: [B, T_in, F]
target_seq: [B, T_out, F] (optional, for teacher forcing)
returns: [B, T_out, F] predicted absolute features (with x,y = last_obs + residuals)
"""
batch_size = x.size(0)
encoder_outputs, (h, c) = self.encoder(x) # encoder_outputs: [B, T_in, H]
h, c = h[-1], c[-1] # take final layer states -> [B, H]
last_obs = x[:, -1, :] # [B, F] last observed features (scaled)
last_xy = last_obs[:, :2] # [B, 2]
decoder_input = last_obs # start input is last observed feature vector
outputs = []
for t in range(self.output_steps):
# attention weights over encoder outputs: combine h and c to get context
attn_weights = torch.softmax(self.attn(torch.cat((h, c), dim=1)), dim=1) # [B, T_in]
context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1) # [B, H]
dec_in = torch.cat((decoder_input, context), dim=1) # [B, F+H]
dec_in = self.attn_combine(dec_in) # [B, F]
h, c = self.decoder(dec_in, (h, c)) # h: [B, H]
residual = self.output_layer(h) # [B, F] -- residual proposal for all features
# Apply residual only to x,y (first two dims); other features are predicted directly
out_xy = residual[:, :2] + last_xy # absolute x,y = last_xy + residual
out_rest = residual[:, 2:] # predicted auxiliary features
out = torch.cat([out_xy, out_rest], dim=1) # [B, F]
outputs.append(out.unsqueeze(1)) # accumulate
# Teacher forcing or autoregressive input
if self.training and (target_seq is not None) and (t < target_seq.size(1)) and (random.random() < teacher_forcing_ratio):
decoder_input = target_seq[:, t, :] # teacher forcing uses ground truth step (scaled)
# update last_xy to ground truth (so if teacher forced, we align residual base)
last_xy = decoder_input[:, :2]
else:
decoder_input = out
last_xy = out[:, :2] # update last_xy to predicted absolute xy
return torch.cat(outputs, dim=1) # [B, T_out, F]
def weighted_coord_loss(pred, target, coord_weight=5.0, reduction='mean'):
"""
Give higher weight to coordinate errors (first two dims).
pred/target: [B, T, F]
"""
coord_loss = F.smooth_l1_loss(pred[..., :2], target[..., :2], reduction=reduction)
aux_loss = F.smooth_l1_loss(pred[..., 2:], target[..., 2:], reduction=reduction)
return coord_weight * coord_loss + aux_loss
def xy_smoothness_loss(seq_full):
"""
Smoothness penalty only on x,y channels of a full sequence.
seq_full: [B, T, F] (F >= 2)
Returns scalar L2 of second differences on x,y.
"""
xy = seq_full[..., :2] # [B, T, 2]
v = xy[:, 1:, :] - xy[:, :-1, :] # velocity [B, T-1, 2]
a = v[:, 1:, :] - v[:, :-1, :] # acceleration [B, T-2, 2]
return (v**2).mean() * 0.05 + (a**2).mean() * 0.5
def train_model(model, loader, val_dl, optimizer, device, epochs=50,
coord_weight=5.0, smooth_w_init=1e-3):
best_loss = float('inf')
best_state = None
for epoch in range(epochs):
model.train()
total_loss = total_data_loss = total_smooth_loss = 0.0
for batch_x, batch_y in loader:
batch_x = batch_x.to(device) # [B, T_in, F]
batch_y = batch_y.to(device) # [B, T_out, F]
# --- Convert targets to residuals ---
# residual = y[t] - y[t-1], relative to last obs (x[:,-1])
y_start = batch_x[:, -1:, :2] # last observed xy [B,1,2]
residual_y = batch_y.clone()
residual_y[..., :2] = batch_y[..., :2] - torch.cat([y_start, batch_y[:, :-1, :2]], dim=1)
optimizer.zero_grad()
# Model predicts residuals
pred_residuals = model(batch_x, target_seq=residual_y, teacher_forcing_ratio=0.5) # [B, T_out, F]
# Data loss on residuals (only xy are residualized, other features can stay as-is)
loss_data = weighted_coord_loss(pred_residuals, residual_y, coord_weight=coord_weight)
# Smoothness: reconstruct absolute xy sequence first
pred_xy = torch.cumsum(pred_residuals[..., :2], dim=1) + y_start
full_seq = torch.cat([batch_x[..., :2], pred_xy], dim=1) # [B, T_in+T_out, 2]
loss_smooth = xy_smoothness_loss(full_seq)
# Decay smooth weight slightly over epochs
smooth_weight = smooth_w_init * max(0.1, 1.0 - epoch / 30.0)
loss = loss_data + smooth_weight * loss_smooth
# Stability guard
if torch.isnan(loss) or torch.isinf(loss):
print("Skipping batch due to NaN/inf loss.")
continue
loss.backward()
optimizer.step()
total_loss += loss.item()
total_data_loss += loss_data.item()
total_smooth_loss += loss_smooth.item()
avg_loss = total_loss / len(loader)
print(f"Epoch {epoch+1:02d} | Total: {avg_loss:.6f} | Data: {total_data_loss/len(loader):.6f} | Smooth: {total_smooth_loss/len(loader):.6f}")
# Validation
model.eval()
val_loss = 0.0
with torch.no_grad():
for xb, yb in val_dl:
xb, yb = xb.to(device), yb.to(device)
# Compute residuals for validation
y_start = xb[:, -1:, :2]
residual_y = yb.clone()
residual_y[..., :2] = yb[..., :2] - torch.cat([y_start, yb[:, :-1, :2]], dim=1)
pred_residuals = model(xb, target_seq=residual_y, teacher_forcing_ratio=0.0)
data_loss = weighted_coord_loss(pred_residuals, residual_y, coord_weight=coord_weight)
pred_xy = torch.cumsum(pred_residuals[..., :2], dim=1) + y_start
full_seq = torch.cat([xb[..., :2], pred_xy], dim=1)
loss_smooth = xy_smoothness_loss(full_seq)
val_loss += (data_loss + smooth_weight * loss_smooth).item()
val_loss /= len(val_dl)
print(f" Val Loss: {val_loss:.6f}")
if val_loss < best_loss:
best_loss = val_loss
best_state = model.state_dict()
if best_state is not None:
torch.save(best_state, "best_model_seq2seq_residual_xy_08302.pth")
print("✅ Best model saved")
# Setting seeds is extremely important for research purposes to make sure the results are reproducible
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_X = data['train_X']
input_steps = 80
output_steps = 2
# Collapse batch and time dims to compute global min for each feature
flat_train_X = train_X.view(-1, train_X.shape[-1])  # shape: [N*T, F]
input_size = 5
hidden_size = 64
model = Seq2SeqLSTM(input_size, hidden_size, input_steps, output_steps).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.HuberLoss()
train_model(model, train_dl,val_dl, optimizer, device)
def evaluate_with_errors(
model,
test_dl,
proj,
feature_scaler,
device,
num_batches=1,
outputs_are_residual_xy: bool = True,
dup_tol: float = 1e-4,
residual_decode_mode: str = "cumsum", # "cumsum", "independent", "stdonly"
residual_std: np.ndarray = None, # needed if residual_decode_mode="stdonly"
):
"""
Evaluates model and prints/plots errors.
Parameters
----------
outputs_are_residual_xy : bool
If True, model outputs residuals (scaled). Else, absolute coords (scaled).
residual_decode_mode : str
- "cumsum": cumulative sum of residuals (default).
- "independent": treat each residual as absolute offset from last obs.
- "stdonly": multiply by provided residual_std instead of scaler.scale_.
residual_std : np.ndarray
Std dev for decoding (meters), used only if residual_decode_mode="stdonly".
"""
model.eval()
errors_all = []
def inverse_xy_only(xy_scaled, scaler):
"""Inverse transform x,y only (works with StandardScaler/RobustScaler)."""
if hasattr(scaler, "mean_"):
center = scaler.mean_[:2]
scale = scaler.scale_[:2]
elif hasattr(scaler, "center_"):
center = scaler.center_[:2]
scale = scaler.scale_[:2]
else:
raise ValueError("Scaler type not supported for XY inversion")
return xy_scaled * scale + center
def haversine(lon1, lat1, lon2, lat2):
"""Distance (m) between two lon/lat points using haversine formula."""
R = 6371000.0
lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])
dlon = lon2 - lon1
dlat = lat2 - lat1
a = np.sin(dlat/2.0)**2 + np.cos(lat1)*np.cos(lat2)*np.sin(dlon/2.0)**2
return 2 * R * np.arcsin(np.sqrt(a))
with torch.no_grad():
batches = 0
for xb, yb in test_dl:
xb = xb.to(device); yb = yb.to(device)
pred = model(xb, teacher_forcing_ratio=0.0) # [B, T_out, F]
# first sample for visualization/diagnostics
input_seq = xb[0].cpu().numpy()
real_seq = yb[0].cpu().numpy()
pred_seq = pred[0].cpu().numpy()
input_xy_s = input_seq[:, :2]
real_xy_s = real_seq[:, :2]
pred_xy_s = pred_seq[:, :2]
# reconstruct predicted absolute trajectory
if outputs_are_residual_xy:
last_obs_xy_s = input_xy_s[-1]
last_obs_xy_m = inverse_xy_only(last_obs_xy_s, feature_scaler)
# choose decoding mode
if residual_decode_mode == "stdonly":
if residual_std is None:
raise ValueError("residual_std must be provided for 'stdonly' mode")
pred_resid_m = pred_xy_s * residual_std
pred_xy_m = np.cumsum(pred_resid_m, axis=0) + last_obs_xy_m
elif residual_decode_mode == "independent":
if hasattr(feature_scaler, "scale_"):
scale = feature_scaler.scale_[:2]
else:
raise ValueError("Scaler missing scale_ for residual decoding")
pred_resid_m = pred_xy_s * scale
pred_xy_m = pred_resid_m + last_obs_xy_m # no cumsum
else: # default: "cumsum"
if hasattr(feature_scaler, "scale_"):
scale = feature_scaler.scale_[:2]
else:
raise ValueError("Scaler missing scale_ for residual decoding")
pred_resid_m = pred_xy_s * scale
pred_xy_m = np.cumsum(pred_resid_m, axis=0) + last_obs_xy_m
else:
pred_xy_m = inverse_xy_only(pred_xy_s, feature_scaler)
# Handle duplicate first target
trimmed_real_xy_s = real_xy_s.copy()
dropped_duplicate = False
if trimmed_real_xy_s.shape[0] >= 1:
if np.allclose(trimmed_real_xy_s[0], input_xy_s[-1], atol=dup_tol):
trimmed_real_xy_s = trimmed_real_xy_s[1:]
dropped_duplicate = True
# Align lengths
len_pred = pred_xy_m.shape[0]
len_real = trimmed_real_xy_s.shape[0]
min_len = min(len_pred, len_real)
if min_len == 0:
print("No overlapping horizon — skipping batch.")
batches += 1
if batches >= num_batches:
break
else:
continue
pred_xy_m = pred_xy_m[:min_len]
real_xy_m = inverse_xy_only(trimmed_real_xy_s[:min_len], feature_scaler)
input_xy_m = inverse_xy_only(input_xy_s, feature_scaler)
# debug
print("\nDEBUG (scaled -> unscaled):")
print("last_obs_scaled:", input_xy_s[-1])
print("real_first_scaled:", real_xy_s[0])
print("pred_first_scaled_delta:", pred_xy_s[0])
print("dropped_duplicate_target:", dropped_duplicate)
print("pred length:", pred_xy_m.shape[0], "real length:", real_xy_m.shape[0])
print("last_obs_unscaled:", input_xy_m[-1])
print("pred_first_unscaled:", pred_xy_m[0])
# lon/lat conversion
lon_in, lat_in = proj(input_xy_m[:,0], input_xy_m[:,1], inverse=True)
lon_real, lat_real= proj(real_xy_m[:,0], real_xy_m[:,1], inverse=True)
lon_pred, lat_pred= proj(pred_xy_m[:,0], pred_xy_m[:,1], inverse=True)
# comparison table
print("\n=== Predicted vs True (lat/lon) ===")
print(f"{'t':>3} | {'lon_true':>9} | {'lat_true':>9} | {'lon_pred':>9} | {'lat_pred':>9} | {'err_m':>9}")
errors = []
for t in range(len(lon_real)):
err_m = haversine(lon_real[t], lat_real[t], lon_pred[t], lat_pred[t])
errors.append(err_m)
print(f"{t:3d} | {lon_real[t]:9.5f} | {lat_real[t]:9.5f} | {lon_pred[t]:9.5f} | {lat_pred[t]:9.5f} | {err_m:9.2f}")
errors_all.append(errors)
# plot
fig = plt.figure(figsize=(8, 6))
ax = plt.axes(projection=ccrs.PlateCarree())
# set extent dynamically around trajectory
all_lons = np.concatenate([lon_in, lon_real, lon_pred])
all_lats = np.concatenate([lat_in, lat_real, lat_pred])
lon_min, lon_max = all_lons.min() - 0.01, all_lons.max() + 0.01
lat_min, lat_max = all_lats.min() - 0.01, all_lats.max() + 0.01
ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
# add map features
ax.add_feature(cfeature.COASTLINE)
ax.add_feature(cfeature.LAND, facecolor="lightgray")
ax.add_feature(cfeature.OCEAN, facecolor="lightblue")
# plot trajectories
ax.plot(lon_in, lat_in, "o-", label="history", transform=ccrs.PlateCarree(),
markersize=6, linewidth=2)
ax.plot(lon_real, lat_real, "o-", label="true", transform=ccrs.PlateCarree(),
markersize=6, linewidth=2)
ax.plot(lon_pred, lat_pred, "x--", label="pred", transform=ccrs.PlateCarree(),
markersize=8, linewidth=2)
ax.legend()
plt.show()
batches += 1
if batches >= num_batches:
break
# summary
if errors_all:
errors_all = np.array(errors_all)
mean_per_t = errors_all.mean(axis=0)
print("\n=== Summary (meters) ===")
for t, v in enumerate(mean_per_t):
print(f"t={t} mean error: {v:.2f} m")
print(f"mean over horizon: {errors_all.mean():.2f} m, median: {np.median(errors_all):.2f} m")
# --- device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.load("datasets_seq2seq_cartesian5.pt")
train_X, train_Y = data['train_X'], data['train_Y']
val_X, val_Y = data['val_X'], data['val_Y']
test_X, test_Y = data['test_X'], data['test_Y']
# --- load scalers ---
feature_scaler = joblib.load("feature_scaler.pkl") # fitted on x,y,sog,cog_sin,cog_cos
sog_scaler = joblib.load("sog_scaler.pkl")
# --- load projection ---
with open("proj_params.json", "r") as f:
proj_params = json.load(f)
proj = pyproj.Proj(**proj_params)
# --- rebuild & load model ---
input_size = train_X.shape[2] # features per timestep
input_steps = train_X.shape[1] # seq length in
output_steps = train_Y.shape[1] # seq length out
hidden_size = 64
num_layers = 2
best_model = Seq2SeqLSTM(
input_size=5,
hidden_size=64,
input_steps=80,
output_steps=2,
).to(device)
best_model.load_state_dict(torch.load(
"best_model_seq2seq_residual_xy_08302.pth", map_location=device
))
best_model.eval()
import numpy as np
# helper: inverse xy using your loaded feature_scaler (same as in evaluate)
def inverse_xy_only_np(xy_scaled, scaler):
if hasattr(scaler, "mean_"):
center = scaler.mean_[:2]
scale = scaler.scale_[:2]
elif hasattr(scaler, "center_"):
center = scaler.center_[:2]
scale = scaler.scale_[:2]
else:
raise ValueError("Scaler type not supported for XY inversion")
return xy_scaled * scale + center
N = train_X.shape[0]
all_resids = []
for i in range(N):
last_obs_s = train_X[i, -1, :2] # scaled
true_s = train_Y[i, :, :2] # scaled (T_out,2)
last_obs_m = inverse_xy_only_np(last_obs_s, feature_scaler)
true_m = inverse_xy_only_np(true_s, feature_scaler)
if true_m.shape[0] == 0:
continue
resid0 = true_m[0] - last_obs_m
if true_m.shape[0] > 1:
rest = np.diff(true_m, axis=0)
resids = np.vstack([resid0[None, :], rest]) # shape [T_out, 2]
else:
resids = resid0[None, :]
all_resids.append(resids)
# stack to compute std per axis across all time steps and samples
all_resids_flat = np.vstack(all_resids) # [sum_T, 2]
residual_std = np.std(all_resids_flat, axis=0) # [std_dx_m, std_dy_m]
print("Computed residual_std (meters):", residual_std)
evaluate_with_errors(
best_model, test_dl, proj, feature_scaler, device,
num_batches=1,
outputs_are_residual_xy=True,
residual_decode_mode="stdonly",
residual_std=residual_std
)
0 | -61.69744 | -61.70585 | 43.22816 | 43.22385 | 833.31 m
pip install langchain langchain-community langchain-google-genai chromadb sentence-transformers gradio beautifulsoup4 requests
export GOOGLE_API_KEY="your-api-key-here"
import getpass
import os
from sentence_transformers import SentenceTransformer
from langchain.chat_models import init_chat_model
if not os.environ.get("GOOGLE_API_KEY"):
os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")
model = init_chat_model("gemini-2.5-flash", model_provider="google_genai")
embeddings_model = SentenceTransformer('all-MiniLM-L6-v2')
import os
import getpass
from bs4 import BeautifulSoup
import requests
from langchain.chat_models import init_chat_model
from langchain_community.vectorstores import Chroma
from langchain.docstore.document import Document
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain.embeddings import HuggingFaceEmbeddings
# ----------------------
# Setup API keys
# ----------------------
if not os.environ.get("GOOGLE_API_KEY"):
os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")
# ----------------------
# Initialize models
# ----------------------
# Use HuggingFace wrapper so LangChain understands SentenceTransformer
embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
# Gemini model (via LangChain)
llm = init_chat_model("gemini-2.5-flash", model_provider="google_genai")
# ----------------------
# Initialize Chroma vector store
# ----------------------
persist_directory = "./chroma_db"
vectorstore = Chroma(
collection_name="aisviz_docs",
embedding_function=embeddings_model,
persist_directory=persist_directory,
)
# ----------------------
# Scraper (example)
# ----------------------
def scrape_aisviz_docs(base_url="https://aisviz.example.com/docs"):
"""
Scrape AISViz docs and return list of LangChain Documents.
Adjust selectors based on actual site structure.
"""
response = requests.get(base_url)
response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser")
docs = []
for section in soup.find_all("div", class_="doc-section"):
text = section.get_text(strip=True)
docs.append(Document(page_content=text, metadata={"source": base_url}))
return docs
# Example: index docs once
def build_index():
docs = scrape_aisviz_docs()
vectorstore.add_documents(docs)
vectorstore.persist()
# ----------------------
# Retrieval QA
# ----------------------
QA_PROMPT = PromptTemplate(
input_variables=["context", "question"],
template="""
You are an assistant for question-answering tasks about AISViz documentation.
Use the following context to answer the user’s question.
If the answer is not contained in the context, say you don't know.
Context:
{context}
Question:
{question}
Answer concisely:
""",
)
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
retriever=retriever,
chain_type="stuff",
chain_type_kwargs={"prompt": QA_PROMPT},
return_source_documents=True,
)
# ----------------------
# Usage
# ----------------------
def answer_question(question: str):
result = qa_chain({"query": question})
return result["result"], result["source_documents"]
if __name__ == "__main__":
# Uncomment below line if first time indexing
# build_index()
ans, sources = answer_question("What is AISViz used for?")
print("Answer:", ans)
print("Sources:", [s.metadata["source"] for s in sources])
import requests
from bs4 import BeautifulSoup
from urllib.parse import urljoin, urlparse
import time
def scrape_aisviz_docs():
"""Scrape all AISViz documentation pages"""
base_urls = [
"https://aisviz.gitbook.io/documentation",
"https://aisviz.cs.dal.ca",
]
scraped_content = []
visited_urls = set()
for base_url in base_urls:
# Get the main page
response = requests.get(base_url)
soup = BeautifulSoup(response.content, 'html.parser')
# Extract main content
main_content = soup.find('div', class_='page-content') or soup.find('main')
if main_content:
text = main_content.get_text(strip=True)
scraped_content.append({
'url': base_url,
'title': soup.find('title').get_text() if soup.find('title') else 'AISViz Documentation',
'content': text
})
# Find all links to other documentation pages
links = soup.find_all('a', href=True)
for link in links:
href = link['href']
full_url = urljoin(base_url, href)
# Only follow links within the same domain
if urlparse(full_url).netloc == urlparse(base_url).netloc and full_url not in visited_urls:
visited_urls.add(full_url)
try:
time.sleep(0.5) # Be respectful to the server
page_response = requests.get(full_url)
page_soup = BeautifulSoup(page_response.content, 'html.parser')
page_content = page_soup.find('div', class_='page-content') or page_soup.find('main')
if page_content:
text = page_content.get_text(strip=True)
scraped_content.append({
'url': full_url,
'title': page_soup.find('title').get_text() if page_soup.find('title') else 'AISViz Page',
'content': text
})
except Exception as e:
print(f"Error scraping {full_url}: {e}")
continue
return scraped_content
# Scrape all documentation
docs_data = scrape_aisviz_docs()
print(f"Scraped {len(docs_data)} pages from AISViz documentation")def split_text(text, chunk_size=1000, chunk_overlap=200):
"""Split text into overlapping chunks"""
chunks = []
start = 0
while start < len(text):
# Find the end of this chunk
end = start + chunk_size
# If this isn't the last chunk, try to break at a sentence or word boundary
if end < len(text):
# Look for sentence boundary
last_period = text.rfind('.', start, end)
last_newline = text.rfind('\n', start, end)
last_space = text.rfind(' ', start, end)
# Use the best boundary we can find
if last_period > start + chunk_size // 2:
end = last_period + 1
elif last_newline > start + chunk_size // 2:
end = last_newline
elif last_space > start + chunk_size // 2:
end = last_space
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
# Move start position for next chunk (with overlap)
start = end - chunk_overlap if end < len(text) else end
return chunks
# Split all documents into chunks
all_chunks = []
chunk_metadata = []
for doc_data in docs_data:
chunks = split_text(doc_data['content'])
for i, chunk in enumerate(chunks):
all_chunks.append(chunk)
chunk_metadata.append({
'source': doc_data['url'],
'title': doc_data['title'],
'chunk_id': i
})
print(f"Split documentation into {len(all_chunks)} chunks")from sentence_transformers import SentenceTransformer
import chromadb
# Initialize SentenceTransformer model
embeddings_model = SentenceTransformer('all-MiniLM-L6-v2')
# Create Chroma client and collection
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="aisviz_docs")
# Create embeddings for all chunks (this may take a few minutes)
print("Creating embeddings for all chunks...")
chunk_embeddings = embeddings_model.encode(all_chunks, show_progress_bar=True)
# Prepare data for Chroma
chunk_ids = [f"chunk_{i}" for i in range(len(all_chunks))]
metadatas = chunk_metadata
# Add everything to Chroma collection
collection.add(
embeddings=chunk_embeddings.tolist(),
documents=all_chunks,
metadatas=metadatas,
ids=chunk_ids
)
print("Documents indexed and stored in Chroma database")import getpass
import os
# ----------------------
# 1. API Key Setup
# ----------------------
# Check if GOOGLE_API_KEY is already set in the environment.
# If not, securely prompt the user to enter it (won’t show in terminal).
if not os.environ.get("GOOGLE_API_KEY"):
os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")
from langchain.chat_models import init_chat_model
# ----------------------
# 2. Initialize Gemini Model (via LangChain)
# ----------------------
# We create a chat model wrapper around Gemini.
# "gemini-2.5-flash" is a lightweight, fast generation model.
# LangChain handles the API calls under the hood.
model = init_chat_model("gemini-2.5-flash", model_provider="google_genai")
# ----------------------
# 3. Define RAG Function
# ----------------------
def answer_question(question, k=4):
"""
Answer a question using Retrieval-Augmented Generation (RAG)
over AISViz documentation stored in ChromaDB.
Args:
question (str): User question.
k (int): Number of top documents to retrieve.
Returns:
dict with keys: 'answer', 'sources', 'context'.
"""
# Step 1: Create an embedding for the question
# Encode the question into a dense vector using SentenceTransformers.
# This embedding will be used to search for semantically similar docs.
question_embedding = embeddings_model.encode([question])
# Step 2: Retrieve relevant documents from Chroma
# Query the Chroma vector store for the top-k most relevant docs.
results = collection.query(
query_embeddings=question_embedding.tolist(),
n_results=k
)
# Step 3: Build context from retrieved documents
# Extract both content and metadata for each retrieved doc.
retrieved_docs = results['documents'][0]
retrieved_metadata = results['metadatas'][0]
# Join all documents into one context string.
# Each doc is prefixed with its source/title for attribution.
context = "\n\n".join([
f"Source: {meta.get('title', 'AISViz Documentation')}\n{doc}"
for doc, meta in zip(retrieved_docs, retrieved_metadata)
])
# Step 4: Construct prompts
# System prompt: instructs Gemini to behave like a doc-based assistant.
system_prompt = """You are an assistant for question-answering tasks about AISViz documentation and maritime vessel tracking.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer based on the context, just say that you don't know.
Keep the answer concise and helpful.
Always mention which sources you're referencing when possible."""
# User prompt: combines the retrieved context with the actual user question.
user_prompt = f"""Context:
{context}
Question: {question}
Answer:"""
# Step 5: Generate answer using Gemini
# Combine system + user prompts and send to Gemini.
full_prompt = f"{system_prompt}\n\n{user_prompt}"
try:
        response = model.invoke(full_prompt)
        answer = response.content  # Extract plain text from the LangChain chat model response
except Exception as e:
# If Gemini API call fails, catch error and return message
answer = f"Sorry, I encountered an error generating the response: {str(e)}"
# Return a structured result: answer text, sources, and raw context.
return {
'answer': answer,
'sources': [meta.get('source', '') for meta in retrieved_metadata],
'context': context
}
# ----------------------
# 4. Test the Function
# ----------------------
result = answer_question("What is AISViz?")
print("Answer:", result['answer'])
print("Sources:", result['sources'])
import gradio as gr
def chatbot_interface(message, history):
"""Interface function for Gradio chatbot"""
try:
result = answer_question(message)
response = result['answer']
# Add source information to the response
if result['sources']:
unique_sources = list(set(result['sources'][:3]))
sources_text = "\n\n**Sources:**\n" + "\n".join([f"- {source}" for source in unique_sources])
response += sources_text
return response
except Exception as e:
return f"Sorry, I encountered an error: {str(e)}"
# Create Gradio interface
demo = gr.ChatInterface(
fn=chatbot_interface,
title="🚢 AISViz Documentation Chatbot",
description="Ask questions about AISViz, AISdb, and maritime vessel tracking!",
examples=[
"What is AISViz?",
"How do I get started with AISdb?",
"What kind of data does AISViz work with?",
"How can I analyze vessel trajectories?"
],
retry_btn=None,
undo_btn="Delete Previous",
clear_btn="Clear History",
)
if __name__ == "__main__":
demo.launch()