AIS Hardware

How to deploy your own Automatic Identification System (AIS) receiver.

In addition to using the AIS data provided by Spire for the Canadian coasts, you can install AIS receiver hardware to capture AIS data directly. The received data can be processed, stored in databases, and then used with AISdb. This approach adds a data source you control (as illustrated in the pipeline below), lets you tailor data collection to specific needs, integrates seamlessly with AISdb for analysis, and makes it easy to share the data you collect with others.

Pipeline for capturing and sharing your own AIS data with a VHF Antenna and AISdb.

Requirements

  • Raspberry Pi or another computer with internet connectivity (for example, a Raspberry Pi 3 Model B: https://www.raspberrypi.com/products/raspberry-pi-3-model-b/)

  • A 162 MHz AIS receiver, such as the Wegmatt dAISy 2 Channel Receiver

  • An antenna in the VHF frequency band (30 MHz - 300 MHz), e.g., a Shakespeare QC-4 VHF antenna

  • An antenna mount

  • Optionally, a filtered preamp, such as the one sold by Uputronics, to improve signal range and quality

An additional option is a free AIS receiver from MarineTraffic. This option may require you to share the received data with the organization to help expand its AIS-receiving network.

Hardware Setup

  • When setting up your antenna, place it as high as possible and as far away from obstructions and other equipment as is practical.

  • Connect the antenna to the receiver. If using a preamp filter, connect it between the antenna and the receiver.

  • Connect the receiver to your Linux device via a USB cable. If using a preamp filter, power it with a USB cable.

A visual example of the antenna hardware setup that MERIDIAN has available is shown below:

MERIDIAN AIS hardware setup working at Sandy Cove in Halifax, NS - Canada.

Software Setup

Quick Start

Connect the receiver to the Raspberry Pi via a USB port, and then run the configure_rpi.sh script:

curl --proto '=https' --tlsv1.2 https://git-dev.cs.dal.ca/meridian/aisdb/-/raw/master/configure_rpi.sh | bash

This will install the Rust toolchain, the AISdb dispatcher, and the AISdb system service (described below), allowing the receiver to start at boot.

Custom Install

  1. Install Raspberry Pi OS with SSH enabled: Visit https://www.raspberrypi.com/software/ to download and install the Raspberry Pi OS. If using the RPi imager, please ensure you run it as an administrator.

  2. Connect the receiver: Attach the receiver to the Raspberry Pi using a USB cable. Then log in to the Raspberry Pi and update the system with the following command: sudo apt-get update

  3. Install the Rust toolchain: Install the Rust toolchain on the Raspberry Pi using the following command: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

     Afterward, log out and log back in to add Rust and Cargo to the system path.

  4. Validate the hardware configuration:

     • When connected via USB, the AIS receiver is typically found under /dev/ with a name beginning with ttyACM, for example /dev/ttyACM0. Ensure the device is listed in this directory.

     • To test the receiver, use the command sudo cat /dev/ttyACM0 to display its output. If all works as intended, you will see streams of bytes appearing on the screen:

     $ sudo cat /dev/ttyACM0
     !AIVDM,1,1,,A,B4eIh>@0<voAFw6HKAi7swf1lH@s,0*61
     !AIVDM,1,1,,A,14eH4HwvP0sLsMFISQQ@09Vr2<0f,0*7B
     !AIVDM,1,1,,A,14eGGT0301sM630IS2hUUavt2HAI,0*4A
     !AIVDM,1,1,,B,14eGdb0001sM5sjIS3C5:qpt0L0G,0*0C
     !AIVDM,1,1,,A,14eI3ihP14sM1PHIS0a<d?vt2L0R,0*4D
     !AIVDM,1,1,,B,14eI@F@000sLtgjISe<W9S4p0D0f,0*24
     !AIVDM,1,1,,B,B4eHt=@0:voCah6HRP1;?wg5oP06,0*7B
     !AIVDM,1,1,,A,B4eHWD009>oAeDVHIfm87wh7kP06,0*20

  5. Install the network client and dispatcher: (a) from crates.io, using cargo install mproxy-client; or (b) from source, using the local path instead, e.g., cargo install --path ./dispatcher/client

  6. Install systemd services: Set up a new systemd service to run the AIS receiver and dispatcher. Create a new text file ./ais_rcv.service with the contents in the block below, replacing User=ais and /home/ais with the username and home directory chosen in step 1.

./ais_rcv.service
[Unit]
Description="AISDB Receiver"
After=network-online.target
Documentation=https://aisdb.meridian.cs.dal.ca/doc/receiver.html

[Service]
Type=simple
User=ais # Replace with your username
ExecStart=/home/ais/.cargo/bin/mproxy-client --path /dev/ttyACM0 --server-addr 'aisdb.meridian.cs.dal.ca:9921' # Replace home directory
Restart=always
RestartSec=30

[Install]
WantedBy=default.target

This service will broadcast receiver input downstream to aisdb.meridian.cs.dal.ca via UDP. You can add additional endpoints at this stage; for more information, see mproxy-client --help. Additional AIS networking tools, such as mproxy-forward, mproxy-server, and mproxy-reverse, are available in the ./dispatcher source directory.

Next, link and enable the service on the Raspberry Pi to ensure the receiver starts at boot:

sudo systemctl enable systemd-networkd-wait-online.service
sudo systemctl link ./ais_rcv.service
sudo systemctl daemon-reload
sudo systemctl enable ais_rcv
sudo systemctl start ais_rcv

See more examples in docker-compose.yml

💡 Common Issues

For some Raspberry Pi hardware (such as the author's Raspberry Pi 4 Model B Rev 1.5), when connecting dAISy AIS receivers, the Linux device file representing the serial interface is not always /dev/ttyACM0, as assumed in our ./ais_rcv.service.

You can check the actual device file in use by running:

ls -l /dev

For example, the author found that serial0 was linked to ttyS0.

Simply changing /dev/ttyACM0 to /dev/ttyS0 may still result in garbled AIS signals, because the default baud rate settings differ. You can modify the default baud rate for ttyS0 with the following command:

stty -F /dev/ttyS0 38400 cs8 -cstopb -parenb

    Introduction

    A little bit about where we stand.

    Overview

Welcome to AISdb - a comprehensive gateway for Automatic Identification System (AIS) data uses and applications. AISdb is part of the Making Vessels Tracking Data Available to Everyone (AISViz) project within the Marine Environmental Research Infrastructure for Data Integration and Application Network (MERIDIAN) initiative at Dalhousie University, and is designed to streamline the collection, processing, and analysis of AIS data, both in live-streaming scenarios and through historical records.

    The primary features AISdb provides include:

SQL database for storing AIS position reports and vessel metadata: At the heart of AISdb is a database built on SQLite, giving users a friendly Python interface with which to interact. This interface simplifies tasks like database creation, data querying, processing, visualization, and exporting data to CSV format for diverse uses. To cater to advanced needs, AISdb also supports Postgres, offering superior concurrency handling and data-sharing capabilities for collaborative environments.

    Vessel data cleaning and trajectory modeling: AISdb includes vessel position cleaning and trajectory modeling features. This ensures that the data used for analyses is accurate and reliable, providing a solid foundation for further studies and applications.

    Integration with environmental context and external metadata: One of AISdb's unique features is its ability to enrich AIS datasets with environmental context. Users can seamlessly integrate oceanographic and bathymetric data in raster formats to bring depth to their analyses — quite literally, as the tool allows for incorporating seafloor depth data underneath vessel positions. Such versatility ensures that AISdb users can merge various environmental data points with AIS information, resulting in richer, multi-faceted maritime studies.

    Advanced features for maritime studies: AISdb offers network graph analysis, MMSI deduplication, interpolation, and other processing utilities. These features enable advanced data processing and analysis, supporting complex maritime studies and applications.

Python interface and machine learning for vessel behavior modeling: AISdb provides a Python interface backed by a Rust core, paving the way for incorporating machine learning and deep learning techniques into vessel behavior modeling in an optimized way. This aspect of AISdb enhances the reproducibility and scalability of research, be it for academic exploration or practical industry applications.

    Research support: AISdb is more than just a storage and processing tool; it is a comprehensive platform designed to support research. Through a formal partnership with our research initiative (contact us for more information), academics, industry experts, and researchers can access extensive Canadian AIS data up to 100 km from the Canadian coastline. This dataset spans from January 2012 to the present and is updated monthly. AISdb offers raw and parsed data formats, eliminating preprocessing needs and streamlining AIS-related research.

The AISViz team is based in the Modeling and Analytics on Predictive Systems (MAPS) lab, in collaboration with the Maritime Risk and Safety (MARS) research group at Dalhousie University. Funded by the Department of Fisheries and Oceans Canada (DFO), our mission revolves around democratizing AIS data use, making it accessible and understandable across multiple sectors, from government and academia to NGOs and the broader public. In addition, AISViz aims to introduce machine learning applications into AISdb's handling of AIS data, streamlining user interactions with AIS data and enhancing the user experience by simplifying data access.

    Our commitment goes beyond just providing tools. Through AISViz, we're opening doors to innovative research and policy development, targeting environmental conservation, maritime traffic management, and much more. Whether you're a professional in the field, an educator, or a maritime enthusiast, AISViz and its components, including AISdb, offer the knowledge and technology to deepen your understanding and significantly impact marine vessel tracking and the well-being of our oceans.

    Our Team

    Active Members

    • Ruixin Song is a research assistant in the Computer Science Department at Dalhousie University. She has an M.Sc. in Computer Science and a B.Eng. in Spatial Information and Digital Technology. Her recent work focuses on marine traffic data analysis and physics-inspired models, particularly in relation to biological invasions in the ocean. Her research interests include mobility data mining, graph neural networks, and network flow and optimization problems.

  • Contact: rsong@dal.ca

• Gabriel Spadon is an Assistant Professor at the Faculty of Computer Science at Dalhousie University, Halifax - NS, Canada. He holds a Ph.D. and an M.Sc. in Computer Science from the University of Sao Paulo, Sao Carlos - SP, Brazil. His research focuses on spatio-temporal analytics, time-series forecasting, and complex network mining, with a deep involvement in data science and engineering, as well as geoinformatics.

  • Contact: spadon@dal.ca

• Ron Pelot has a Ph.D. in Management Sciences and is a Professor of Industrial Engineering at Dalhousie University. For the last 30 years, he and his team have been working on developing new software tools and analysis methods for maritime traffic safety, coastal zone security, and marine spills. Their research methods include spatial risk analysis, vessel traffic modeling, data processing, pattern analysis, location models for response resource allocation, safety analyses, and cumulative shipping impact studies.

  • Contact: ronald.pelot@dal.ca

    Adjunct Members

• Vaishnav Vaidheeswaran is a Master's student in Computer Science at Dalhousie University. He holds a B.Tech in Computer Science and Engineering and has three years of experience as a software engineer in India, working at cutting-edge startups. His ongoing work addresses incorporating spatial knowledge into trajectory forecasting models to reduce the aleatoric uncertainty arising from stochastic interactions of the vessel with the environment. His research interests include large language models, graph neural networks, and reinforcement learning.

  • Contact: vaishnav@dal.ca

• Parth Doshi is a Bachelor's student in Computer Science at Dalhousie University. His ongoing work addresses developing a Generative Adversarial Network aiming to generate vessel trajectories through agents, as well as vessel spoofing detection. His research interests include time-series forecasting and inverse reinforcement learning.

  • Contact: parth.doshi@dal.ca

    Former Members

• Jinkun Chen is a Ph.D. student in Computer Science at Dalhousie University, specializing in Explainable AI, Natural Language Processing (NLP), and Visualization. He earned a bachelor's degree in Computer Science with First-Class Honours from Dalhousie University. Jinkun is actively involved in research, working on advancing fairness, responsibility, trustworthiness, and explainability within Large Language Models (LLMs) and AI.

• Jay Kumar has a Ph.D. in Computer Science and Technology and was a postdoctoral fellow at the Department of Industrial Engineering at Dalhousie University. He has researched AI models for time-series data for over five years, focusing on Recurrent Neural models, probabilistic modeling, and feature engineering data analytics applied to ocean traffic. His research interests include Spatio-temporal Data Mining, Stochastic Modeling, Machine Learning, and Deep Learning.

• Matthew Smith has a BSc degree in Applied Computer Science from Dalhousie University and specializes in managing and analyzing vessel tracking data. He is currently a Software Engineer at Radformation in Toronto, ON. Matt served as the AIS data manager on the MERIDIAN project, where he supported research groups across Canada in accessing and utilizing AIS data. The data was used to answer a range of scientific queries, including the impact of shipping on underwater noise pollution and the danger posed to endangered marine mammals by vessel collisions.

• Casey Hilliard has a BSc degree in Computer Science from Dalhousie University and was a Senior Data Manager at the Institute for Big Data Analytics. He is currently a Chief Architect at GSTS (Global Spatial Technology Solutions) in Dartmouth, NS. Casey was a long-time research support staff member at the Institute and an expert in managing and using AIS vessel-tracking data. During his time, he assisted in advancing the Institute's research projects by managing and organizing large datasets, ensuring data integrity, and facilitating data usage in research.

• Stan Matwin was the director of the Institute for Big Data Analytics, Dalhousie University, Halifax, Nova Scotia; he is a professor and Canada Research Chair (Tier 1) in Interpretability for Machine Learning. He is also a distinguished professor (Emeritus) at the University of Ottawa and a full professor with the Institute of Computer Science, Polish Academy of Sciences. His main research interests include big data, text mining, machine learning, and data privacy. He is a member of the Editorial Boards of IEEE Transactions on Knowledge and Data Engineering and the Journal of Intelligent Information Systems. He received the Lifetime Achievement Award of the Canadian AI Association (CAIAC).

    Contact

We are passionate about fostering a collaborative and engaged community, and we welcome your questions, insights, and feedback as vital components of our continuous improvement and innovation. Should you have any inquiries about AISdb, want further information on our research, or wish to explore potential collaborations, please don't hesitate to contact us. Staying connected with users and researchers plays a crucial role in shaping the tool's development and ensuring it meets the diverse needs of our growing user base. You can easily reach our team via email or our GitHub team platform. In addition to addressing individual queries, we are committed to organizing webinars and workshops and presenting at conferences to share knowledge, gather feedback, and widen our outreach (stay tuned for more information). Together, let's advance the understanding and use of marine data for a brighter, more informed future in ocean research and preservation.



    Quick Start

    A hands-on quick start guide for using AISdb.

If you are new to AIS topics, click here to learn about the Automatic Identification System (AIS).

Note: If you are starting from scratch, download the data ".db" file from our AISdb Tutorial GitHub repository so that you can follow this guide properly.

    Python Environment and Installation

To work with the AISdb Python package, please ensure you have Python version 3.8 or higher. If you plan to use SQLite, no additional installation is required, as it is included with Python by default. However, those who prefer a PostgreSQL server must install it separately and enable the TimescaleDB extension for it to function correctly.

    User Installation

    The AISdb Python package can be conveniently installed using pip. It's highly recommended that a virtual Python environment be created and the package installed within it.

    You can test your installation by running the following commands:

Note that if you are running Jupyter, ensure it is installed in the same environment as AISdb:

    The Python code in the rest of this document can be run in the Python environment you created.

    Development Installation

To use nightly builds (not mandatory), you can install AISdb from source:

Alternatively, you can build from source on Google Colab as follows:

    Database Handling

AISdb supports SQLite and PostgreSQL databases. Since version 1.7.3, AISdb requires TimescaleDB over PostgreSQL to function properly. To install TimescaleDB, follow these steps:

    Install TimescaleDB (PostgreSQL Extension)

    Enable the Extension in PostgreSQL

    Verify the Installation

    Restart PostgreSQL

    Connecting to a PostgreSQL database

This option requires the optional dependency psycopg for interfacing with Postgres databases. Note that Postgres accepts these keyword arguments. Alternatively, a connection string may be used; information on connection strings and the Postgres URI format can be found here.

    Attaching a SQLite database to AISdb

Querying SQLite is as easy as providing the name of a ".db" file with the same entity-relationship model as the databases supported by AISdb, which is detailed in the SQL Database section. We prepared an example SQLite database, example_data.db, based on AIS data from Marine Cadastre for a small region near Maine, United States, in January 2022; it is available in the AISdb Tutorial GitHub repository.

If you want to create a database from your own data, we have a tutorial with examples that shows you how to create an SQLite database from open-source data.

    Querying the Database

Parameters for the database query can be defined using aisdb.database.dbqry.DBQuery. Iterate over rows returned from the database for each vessel with aisdb.database.dbqry.DBQuery.gen_qry(). Convert the results into a generator yielding dictionaries with NumPy arrays describing position vectors, e.g., lon, lat, and time, using aisdb.track_gen.TrackGen().

    The following query will return vessel trajectories from a given 1-hour time window:

A specific region can be queried for AIS data using aisdb.gis.Domain or one of its sub-classes to define a collection of shapely polygon features. For this example, the domain contains a single bounding-box polygon derived from a longitude/latitude coordinate pair and a radial distance specified in meters. If multiple features are included in the domain object, the domain boundaries will encompass the convex hull of all features.

Additional query callbacks for filtering by region, timeframe, identifier, etc. can be found in aisdb.database.sql_query_strings and aisdb.database.sqlfcn_callbacks.

    Processing

    Voyage Modelling

    The above generator can be input into a processing function, yielding modified results. For example, to model the activity of vessels on a per-voyage or per-transit basis, each voyage is defined as a continuous vector of positions where the time between observed timestamps never exceeds 24 hours.

    Data cleaning and MMSI deduplication

    A common problem with AIS data is noise, where multiple vessels might broadcast using the same identifier (sometimes simultaneously). In such cases, AISdb can denoise the data:

(1) Denoising with Encoder: The aisdb.denoising_encoder.encode_greatcircledistance() function checks the approximate distance between each of a vessel's consecutive positions and splits the vector wherever the vessel could not reasonably have traveled along the most direct path, such as at speeds over 50 knots.

    (2) Distance and Speed Thresholds: A distance and speed threshold limits the maximum distance or time between messages that can be considered continuous.

    (3) Scoring and Segment Concatenation: A score is computed for each position delta, with sequential messages nearby at shorter intervals given a higher score. This score is calculated by dividing the Haversine distance by elapsed time. Any deltas with a score not reaching the minimum threshold are considered the start of a new segment. New segments are compared to the end of existing segments with the same vessel identifier; if the score exceeds the minimum, they are concatenated. If multiple segments meet the minimum score, the new segment is concatenated to the existing segment with the highest score.

    Notice that processing functions may be executed in sequence as a chain or pipeline, so after segmenting the individual voyages as shown above, results can be input into the encoder to remove noise and correct for vessels with duplicate identifiers.

    Interpolating, geofencing, and filtering

    Building on the above processing pipeline, the resulting cleaned trajectories can be geofenced and filtered for results contained by at least one domain polygon and interpolated for uniformity.

Additional processing functions can be found in the aisdb.track_gen module.

    Exporting as CSV

    The resulting processed voyage data can be exported in CSV format instead of being printed:

    Integration with external metadata

AISdb supports integrating external data sources such as bathymetric charts and other raster grids.

    Bathymetric charts

To determine the approximate ocean depth at each vessel position, the aisdb.webdata.bathymetry module can be used.

Once the data has been downloaded, the Gebco() class may be used to append bathymetric data to tracks in the context of a processing pipeline, as with the processing functions described above.

Also, see aisdb.webdata.shore_dist.ShoreDist for determining the approximate nearest distance to shore from vessel positions.

    Rasters

Similarly, arbitrary coordinate-gridded raster data may be appended to vessel tracks.

    Visualization

AIS data from the database may be overlaid on a map, such as the one shown above, using the aisdb.web_interface.visualize() function. This function accepts a generator of track dictionaries such as those output by aisdb.track_gen.TrackGen().

For a complete plug-and-play solution, you may clone our Google Colab Notebook.

    SQL Database

    Table Naming

When loading data into the database, messages will be sorted into SQL tables determined by the message type and month. The table names follow formats such as ais_{YYYYMM}_dynamic, ais_{YYYYMM}_static, and static_{YYYYMM}_aggregate, where {YYYYMM} indicates the year and month of the data (see the Detailed Table Description below).

Some additional tables containing computed data may be created depending on the indexes used, for example, an aggregate of vessel static data by month, or a virtual table used as a covering index.

    Database Loading

    This tutorial will guide you in using the AISdb package to load AIS data into a database and perform queries. We will begin with AISdb installation and environment setup, then proceed to examples of querying the loaded data and creating simple visualizations.

    Install Requirements

    Preparing a Python virtual environment for AISdb is a safe practice. It allows you to manage dependencies and prevent conflicts with other projects, ensuring a clean and isolated setup for your work with AISdb. Run these commands in your terminal based on the operating system you are using:

    Data Querying

Data querying with AISdb involves setting up a connection to the database, defining query parameters, creating and executing the query, and processing the results. In the previous tutorial, Database Loading, we set up a database connection and made simple queries and visualizations. This tutorial digs into the data query functions and parameters and shows you the queries you can make with AISdb.

    Query functions

Data querying with AISdb includes two components: DBQuery and TrackGen. In this section, we will introduce each component with examples. Before starting, please ensure you have connected to the database; if you have not done so, please follow the instructions and examples in Database Loading.

    Data Visualization

This tutorial introduces visualization options for vessel trajectories processed with AISdb, including AISdb's integrated web interface and alternative approaches using popular Python visualization packages. Practical examples are provided for each tool, illustrating how to process and visualize AISdb tracks effectively.

    Internal visualization

AISdb provides an integrated data visualization feature through the aisdb.web_interface module, which allows users to generate interactive maps displaying vessel tracks. This built-in tool is designed for simplicity and ease of use, offering customizable visualizations directly from AIS data without requiring extensive setup.

    Visualization of vessel tracks within a defined time range
    Linux
    python -m venv AISdb   # create a python virtual environment
    source ./AISdb/bin/activate  # activate the virtual environment
    pip install aisdb  # from https://pypi.org/project/aisdb/
    Windows
    python -m venv AISdb
    ./AISdb/Scripts/activate  
    pip install aisdb
    python
    >>> import aisdb
    >>> aisdb.__version__  # should return '1.7.3' or newer
    source ./AISdb/bin/activate
    pip install jupyter
    jupyter notebook
    source AISdb/bin/activate  # On Windows use `AISdb\Scripts\activate`
    
    # Cloning the Repository and installing the package
    git clone https://github.com/AISViz/AISdb.git && cd aisdb
    
    # Windows users can instead download the installer:
    #   - https://forge.rust-lang.org/infra/other-installation-methods.html#rustup
    #   - https://static.rust-lang.org/rustup/dist/i686-pc-windows-gnu/rustup-init.exe
    curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs > install-rust.sh
    
    # Installing Rust and Maturin
    /bin/bash install-rust.sh -q -y
    pip install --upgrade maturin[patchelf]
    
    # Building AISdb package with Maturin
    maturin develop --release --extras=test,docs
    import os
    # Clone the AISdb repository from GitHub
    !git clone https://github.com/AISViz/AISdb.git
    # Install Rust using the official Rustup script
    !curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
    # Install Maturin to build the packages
    !pip install --upgrade maturin[patchelf]
    # Set up environment variables
    os.environ["PATH"] += os.pathsep + "/root/.cargo/bin"
    # Install wasm-pack for building WebAssembly packages
    !curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
    # Install wasm-pack as a Cargo dependency
    !cargo install wasm-pack
    # Setting environment variable for the virtual environment
    os.environ["VIRTUAL_ENV"] = "/usr/local"
    # Change directory to AISdb for building the package
    %cd AISdb
    # Build and install the AISdb package using Maturin
    !maturin develop --release --extras=test,docs
    $ sudo apt install -y timescaledb-postgresql-XX  # XX is the PG-SQL version
    > CREATE EXTENSION IF NOT EXISTS timescaledb;
    > SELECT * FROM timescaledb_information.version;
    $ sudo systemctl restart postgresql
    from aisdb.database.dbconn import PostgresDBConn
    
    # [OPTION 1]
    dbconn = PostgresDBConn(
        hostaddr='127.0.0.1',  # Replace this with the Postgres address (supports IPv6)
        port=5432,  # Replace this with the Postgres running port (if not the default)
        user='USERNAME',  # Replace this with the Postgres username
        password='PASSWORD',  # Replace this with your password
        dbname='DATABASE',  # Replace this with your database name
    )
    
    # [OPTION 2]
    dbconn = PostgresDBConn('postgresql://USERNAME:PASSWORD@HOST:PORT/DATABASE')
    from aisdb.database.dbconn import SQLiteDBConn 
    
    dbpath='example_data.db'
    dbconn = SQLiteDBConn(dbpath=dbpath)
    import aisdb
    import pandas as pd
    from datetime import datetime
    from collections import defaultdict
    
    dbpath = 'example_data.db'
    start_time = datetime.strptime("2022-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
    end_time = datetime.strptime("2022-01-01 0:59:59", '%Y-%m-%d %H:%M:%S')
    
    def data2frame(tracks):
        # Dictionary where for key/value
        ais_data = defaultdict(lambda: pd.DataFrame(
            columns = ['time', 'lat', 'lon', 'cog', 'rocog', 'sog', 'delta_sog']))
    
        for track in tracks:
            mmsi = track['mmsi']
            df = pd.DataFrame({
                'time': pd.to_datetime(track['time'], unit='s'),
                'lat': track['lat'], 'lon': track['lon'],
                'cog': track['cog'], 'sog': track['sog']
            })
    
            # Sort by time in descending order
            df = df.sort_values(by='time', ascending=False).reset_index(drop=True)
            # Compute the time difference in seconds
            df['time_diff'] = df['time'].diff().dt.total_seconds()
            # Compute RoCOG (Rate of Change of Course Over Ground)
            delta_cog = (df['cog'].diff() + 180) % 360 - 180
            df['rocog'] = delta_cog / df['time_diff']
            # Compute Delta SOG (Rate of Change of Speed Over Ground)
            df['delta_sog'] = df['sog'].diff() / df['time_diff']
            # Fill NaN values (first row) and infinite values (division by zero cases)
            df[['rocog', 'delta_sog']] = df[['rocog', 'delta_sog']].replace([float('inf'), float('-inf')], 0).fillna(0)
            # Drop unnecessary column
            df.drop(columns = ['time_diff'], inplace=True)
            # Store in the dictionary
            ais_data[mmsi] = df
        return ais_data
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
        )
    
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        ais_data = data2frame(tracks)  # re-use previous function
    
    # Display DataFrames
    for key in ais_data.keys():
        print(ais_data[key])
    # a circle with a 100km radius around the location point
    domain = aisdb.DomainFromPoints(points=[(-69.34, 41.55)], radial_distances=[100000])
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_validmmsi_bbox,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        ais_data = data2frame(tracks)  # re-use previous function
    
    # Display DataFrames
    for key in ais_data.keys():
        print(ais_data[key])
    from datetime import timedelta
    
    # Define a maximum time interval
    maxdelta = timedelta(hours=24)
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_validmmsi_bbox,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    
        # Split the generated tracks into segments
        track_segments = aisdb.split_timedelta(tracks, maxdelta)
        ais_data = data2frame(track_segments)  # re-use previous function
        
        # Display DataFrames
        for key in ais_data.keys():
            print(ais_data[key])
    distance_threshold = 20000  # the maximum allowed distance (meters) between consecutive AIS messages
    speed_threshold = 50        # the maximum allowed vessel speed in consecutive AIS messages
    minscore = 1e-6             # the minimum score threshold for track segment validation
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        
        # Encode the track segments to clean and validate the track data
        tracks_encoded = aisdb.encode_greatcircledistance(tracks, 
                                                          distance_threshold=distance_threshold, 
                                                          speed_threshold=speed_threshold, 
                                                          minscore=minscore)
        ais_data = data2frame(tracks_encoded)  # re-use previous function
        
        # Display DataFrames
        for key in ais_data.keys():
            print(ais_data[key])
    # Define a domain with a central point and corresponding radial distances
    domain = aisdb.DomainFromPoints(points=[(-69.34, 41.55),], radial_distances=[100000,])
    
    # Filter the encoded tracks to include only those within the specified domain
    tracks_filtered = aisdb.track_gen.zone_mask(tracks_encoded, domain)
    
    # Interpolate the filtered tracks with a specified time interval
    tracks_interp = aisdb.interp_time(tracks_filtered, step=timedelta(minutes=15))
    aisdb.write_csv(tracks_interp, 'ais_processed.csv')
    import aisdb
    
    # Set the data storage directory
    data_dir = './testdata/'
    
    # Download bathymetry grid from the internet
    bathy = aisdb.webdata.bathymetry.Gebco(data_dir=data_dir)
    bathy.fetch_bathymetry_grid()
    tracks = aisdb.TrackGen(qry.gen_qry(), decimate=False)
    tracks_bathymetry = bathy.merge_tracks(tracks) # merge tracks with bathymetry data
    tracks = aisdb.TrackGen(qry.gen_qry())
raster_path = './GMT_intermediate_coast_distance_01d.tif'
    
    # Load the raster file
    raster = aisdb.webdata.load_raster.RasterFile(raster_path)
    
    # Merge the generated tracks with the raster data
    tracks = raster.merge_tracks(tracks, new_track_key="coast_distance")
    from datetime import datetime, timedelta
    import aisdb
    from aisdb import DomainFromPoints
    
    dbpath='example_data.db'
    
    def color_tracks(tracks):
        ''' set the color of each vessel track using a color name or RGB value '''
        for track in tracks:
            track['color'] = 'blue' or 'rgb(0,0,255)'
            yield track
    
    # Set the start and end times for the query
    start_time = datetime.strptime("2022-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
    end_time = datetime.strptime("2022-01-31 00:00:00", '%Y-%m-%d %H:%M:%S')
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn,
            start=start_time,
            end=end_time,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
        )
        rowgen = qry.gen_qry()
        
        # Convert queried rows to vessel trajectories
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        
        # Visualization
        aisdb.web_interface.visualize(
            tracks,
            visualearth=False,
            open_browser=True,
        )
    Additional tables are also included for storing data not directly derived from AIS message reports.

    For quick reference to data types and detailed explanations of these table entries, please see the Detailed Table Description.

    Custom SQL Queries

In addition to querying the database using the DBQuery module, there is an option to customize queries with your own SQL code.

    Example of listing all the tables in your database:
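
A minimal sketch using Python's built-in sqlite3 module against the example database (any SQLite client, including the AISdb database connection, would work the same way):

import sqlite3

# List every table in the example SQLite database
with sqlite3.connect('example_data.db') as con:
    rows = con.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
    for (name,) in rows:
        print(name)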

    As messages are separated into tables by message type and month, queries spanning multiple message types or months should use UNIONs and JOINs to combine results as appropriate.

    Example of querying tables with `JOIN`:
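
A sketch of such a JOIN, assuming the database contains January 2022 data; the table and column names follow the conventions in the Detailed Table Description below:

import sqlite3

# Join position reports with static vessel metadata for January 2022
query = """
SELECT d.mmsi, d.time, d.longitude, d.latitude, s.vessel_name, s.ship_type
FROM ais_202201_dynamic AS d
LEFT JOIN ais_202201_static AS s ON d.mmsi = s.mmsi
LIMIT 10;
"""
with sqlite3.connect('example_data.db') as con:
    for row in con.execute(query):
        print(row)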

More information about SQL queries can be found in online tutorials.

    The R* tree virtual tables should be queried for AIS position reports instead of the default tables. Query performance can be significantly improved using the R* tree index when restricting output to a narrow range of MMSIs, timestamps, longitudes, and latitudes. However, querying a wide range will not yield much benefit. If custom indexes are required for specific manual queries, these should be defined on message tables 1_2_3, 5, 18, and 24 directly instead of upon the virtual tables.

    Timestamps are stored as epoch minutes in the database. To facilitate querying the database manually, use the dt_2_epoch() function to convert datetime values to epoch minutes and the epoch_2_dt() function to convert epoch minutes back to datetime values. Here is how you can use dt_2_epoch() with the example above:
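
A hedged sketch of this pattern; the import path aisdb.gis is an assumption, so adjust it to wherever dt_2_epoch is exposed in your AISdb version:

import sqlite3
from datetime import datetime
from aisdb.gis import dt_2_epoch  # module path assumed

# Convert the query window to the epoch units stored in the database
start = dt_2_epoch(datetime(2022, 1, 1, 0, 0))
end = dt_2_epoch(datetime(2022, 1, 1, 1, 0))

query = """
SELECT mmsi, time, longitude, latitude
FROM ais_202201_dynamic
WHERE time BETWEEN ? AND ?
LIMIT 10;
"""
with sqlite3.connect('example_data.db') as con:
    for row in con.execute(query, (start, end)):
        print(row)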

For more examples, please see the SQL code in aisdb_sql/ that is used to create database tables and associated queries.

    Detailed Table Description

    ais_{YYYYMM}_dynamic tables

Column      Data Type   Description
mmsi        INTEGER     Maritime Mobile Service Identity, a unique identifier for vessels.
time        INTEGER     Timestamp of the AIS message, in epoch seconds.
longitude   REAL        Longitude of the vessel in decimal degrees.
latitude    REAL        Latitude of the vessel in decimal degrees.

    ais_{YYYYMM}_static tables

Column       Data Type   Description
mmsi         INTEGER     Maritime Mobile Service Identity, a unique identifier for vessels.
time         INTEGER     Timestamp of the AIS message, in epoch seconds.
vessel_name  TEXT        Name of the vessel.
ship_type    INTEGER     Numeric code representing the type of ship.

    static_{YYYYMM}_aggregate tables

Column       Data Type   Description
mmsi         INTEGER     Maritime Mobile Service Identity, a unique identifier for vessels.
imo          INTEGER     International Maritime Organization number, another unique vessel identifier.
vessel_name  TEXT        Name of the vessel.
ship_type    INTEGER     Numeric code representing the type of ship.

    Now you can check your installation by running:

If you're using AISdb in a Jupyter Notebook, please include the following commands in your notebook cells:

    Then, import the required packages:

    Load AIS data into a database

    This section will show you how to efficiently load AIS data into a database.

    AISdb includes two database connection approaches:

    1. SQLite database connection; and,

    2. PostgreSQL database connection.

    SQLite database connection

    We are working with the SQLite database in most of the usage scenarios. Here is an example of loading data using the sample data included in the AISdb package:
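
A minimal sketch of the loading step; the file path below is a placeholder standing in for the sample file shipped with AISdb (or any raw AIS .csv/.nm4/.zip file of your own), and only the parameters described on this page are used:

import aisdb
from aisdb.database.dbconn import SQLiteDBConn

filepaths = ['./sample_data/ais_sample.csv']  # placeholder path to the sample data
dbpath = './ais_sample.db'                    # new SQLite database to be created

with SQLiteDBConn(dbpath=dbpath) as dbconn:
    # Decode the raw messages and insert them into the SQLite database
    aisdb.decode_msgs(filepaths=filepaths, dbconn=dbconn, source='Spire')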

    The code above decodes the AIS messages from the CSV file specified in filepaths and inserts them into the SQLite database connected via dbconn.

    Following is a quick example of a query and visualization of the data we just loaded with AISdb:

    Visualization of vessel tracks queried from SQLite database created from test data

    PostgreSQL database connection

    In addition to SQLite database connection, PostgreSQL is used in AISdb for its superior concurrency handling and data-sharing capabilities, making it suitable for collaborative environments and handling larger datasets efficiently. The structure and interactions with PostgreSQL are designed to provide robust and scalable solutions for AIS data storage and querying. For PostgreSQL, you need the psycopg2 library:
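
The library can be installed into the same virtual environment as AISdb, for example:

pip install psycopg2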

    To connect to a PostgreSQL database, AISdb uses the PostgresDBConn class:

After establishing a connection to the PostgreSQL database, specifying the paths of the data files, and calling the aisdb.decode_msgs function for data processing, the following operations are performed in order: data file processing, table creation, data insertion, and index rebuilding.

    Please pay close attention to the flags in aisdb.decode_msgs, as recent updates provide more flexibility for database configurations. These updates include support for ingesting NOAA data into the aisdb format and the option to structure tables using either the original B-Tree indexes or TimescaleDB’s structure when the extension is enabled. In particular, please take care of the following parameters:

    • source (str, optional) Specifies the data source to be processed and loaded into the database.

      • Options: "Spire", "NOAA"/"noaa", or leave empty.

  • Default: empty, in which case processing proceeds assuming the Spire source.

    • raw_insertion (bool, optional)

      • If False, the function will drop and rebuild indexes to speed up data loading.

      • Default: True

    • timescaledb (bool, optional)

      • Set to True only if using the TimescaleDB extension in your PostgreSQL database.

  • Refer to the Database Handling section in the Quick Start guide for proper setup and usage.

    Example: Processing a Full Year of Spire Data (2024)

    The following example demonstrates how to process and load Spire data for the entire year 2024 into an aisdb database with the TimescaleDB extension installed:
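
A sketch of what such a loading script could look like; the file layout, directory paths, and connection string are assumptions, while the decode_msgs flags follow the parameter notes above:

import glob
import aisdb
from aisdb.database.dbconn import PostgresDBConn

# Connection string for the TimescaleDB-enabled PostgreSQL server (placeholder credentials)
dbconn = PostgresDBConn('postgresql://USERNAME:PASSWORD@HOST:PORT/DATABASE')

# One year of raw Spire deliveries, e.g. /data/spire/2024/spire_2024-01.zip ... spire_2024-12.zip
filepaths = sorted(glob.glob('/data/spire/2024/*.zip'))

aisdb.decode_msgs(
    filepaths=filepaths,
    dbconn=dbconn,
    source='Spire',       # Spire-formatted input
    raw_insertion=True,   # see the parameter notes above
    timescaledb=True,     # tables structured for the TimescaleDB extension
)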

    Example of performing queries and visualizations with PostgreSQL database:

    Visualization of tracks queried from PostgreSQL database

    Moreover, if you wish to use your own AIS data to create and process a database with AISdb, please check out our instructional guide on data processing and database creation: Using Your AIS Data.


    Query database

The DBQuery class is used to create a query object that specifies the parameters for data retrieval, including the time range, spatial domain, and any filtering callbacks. Here is an example of creating a DBQuery object and using parameters to specify the time range and geographical locations:

    Callback functions

Callback functions are used in the DBQuery class to filter data based on specific criteria. Some common callbacks include in_bbox, in_time_bbox, valid_mmsi, and in_time_bbox_validmmsi. These callbacks ensure that the data retrieved matches the criteria defined in the query. Please find examples of using different callbacks with other parameters in Query types with practical examples.

For more callback functions, refer to the API documentation here: API-Doc.

    Method gen_qry

The gen_qry function is a method of the DBQuery class in AISdb. It is responsible for generating rows of data that match the query criteria specified when creating the DBQuery object. This function acts as a generator, yielding one row at a time and efficiently handling large datasets.

    Two callback functions can be passed to gen_qry. They are:

  • crawl_dynamic: Iterates only over the position reports table. This is called by default.

    • crawl_dynamic_static: Iterates over both position reports and static messages tables.

After creating the DBQuery object, we can generate rows with gen_qry:
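
A short sketch; the module path aisdb.database.sqlfcn and the keyword name fcn used to select between the two callbacks are assumptions about the gen_qry signature, so check the API documentation for your version:

from aisdb.database import sqlfcn

# Generate rows over both position reports and static messages
rowgen = qry.gen_qry(fcn=sqlfcn.crawl_dynamic_static)

for row in rowgen:
    print(row)  # inspect a single record
    break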

    Each row from gen_qry is a tuple or dictionary representing a record in the database.

    Generate trajectories

The TrackGen class converts the generated rows from gen_qry into tracks (trajectories). It takes the row generator and, optionally, a decimate parameter to control point reduction. This conversion is essential for analyzing vessel movements, identifying patterns, and visualizing trajectories in later steps.

    Following the generated rows above, here is how to use the TrackGen class:

The TrackGen class yields a generator object of "tracks". When iterating over it, each element is a dictionary representing the track of a specific vessel:

    This is the output with our sample data:

    Query types with practical examples

In this section, we provide practical examples of the most common query types you can run with the DBQuery class, including querying within a time range, within geographical areas, and tracking vessels by MMSI. Different queries can be achieved by changing the callbacks and other parameters defined in the DBQuery class. We then use TrackGen to convert the query results into structured tracks for further analysis and visualization.

    First, we need to import the necessary packages and prepare data:

    Within time range

Querying data within a specified time range can be done by using the in_timerange_validmmsi callback in the DBQuery class:

This will display the queried vessel tracks (within the specified time range and with valid MMSIs) on the map:

    Queried vessel tracks in specified time range

You may find noise in some of the track data. In Data Cleaning, we introduce the de-noising methods in AISdb that can effectively remove unreasonable or erroneous data points, ensuring more accurate and reliable vessel trajectories.

    Within bounding box

    In practical scenarios, people may have specific points/areas of interest. DBQuery includes parameters to define a bounding box and has relevant callbacks. Let's look at an example:

    This will show all the vessel tracks with valid MMSI in the defined bounding box:

    Queried vessel tracks within a defined bounding box

    Combination of multiple conditions

In the above examples, we queried data in a time range and a geographical area. If you want to combine multiple query criteria, please check the available types of callbacks in the API Docs. Continuing the last example, we can simply change the callback type to obtain vessel tracks within both the time range and the geographical area:

    The displayed vessel tracks:

    Queried vessel tracks within a defined bounding box and time range

    Filtering MMSI

In addition to time and location ranges, you can track one or more vessels of interest by specifying their MMSIs in the query. Here is an example of tracking several vessels within a time range:

    Queried tracks of vessels of interest within a specified time range

    Here is an example of using the web interface module to show queried data with colors. To display vessel tracks in a single color:
    Visualizing queried vessel tracks in a single color

    If you want to visualize vessel tracks in different colors based on MMSI, here's an example that demonstrates how to color-code tracks for easy identification:
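
A small sketch of one way to do this, assigning each MMSI a color from a fixed palette before passing the generator to the web interface; the 'tracks' generator is assumed to be prepared as in the earlier queries:

import aisdb

def color_by_mmsi(tracks):
    # Assign each vessel a color derived from its MMSI so different vessels stand out
    palette = ['red', 'orange', 'yellow', 'green', 'blue', 'purple', 'cyan', 'magenta']
    for track in tracks:
        track['color'] = palette[track['mmsi'] % len(palette)]
        yield track

aisdb.web_interface.visualize(color_by_mmsi(tracks), visualearth=False, open_browser=True)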

    Visualizing vessel tracks in multiple colors based on MMSIs

    Alternative visualization

Several alternative Python packages can be leveraged by users seeking more advanced or specialized visualization capabilities. For instance, Contextily, Basemap, and Cartopy are excellent for creating detailed 2D plots, while Plotly offers powerful interactive graphs. Additionally, Kepler.gl caters to users needing dynamic, large-scale visualizations or 3D mapping. These alternatives allow for a deeper exploration of AIS data, offering flexibility in how data is presented and analyzed beyond the default capabilities of AISdb.

    Contextily + Matplotlib
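
A minimal sketch assuming 'tracks' is a TrackGen-style generator of dictionaries with 'lon' and 'lat' arrays; Contextily fetches an OpenStreetMap basemap beneath the plotted tracks:

import matplotlib.pyplot as plt
import contextily as cx

fig, ax = plt.subplots(figsize=(10, 8))
for track in tracks:
    ax.plot(track['lon'], track['lat'], linewidth=1)

# Add an OpenStreetMap basemap in geographic coordinates
cx.add_basemap(ax, crs='EPSG:4326', source=cx.providers.OpenStreetMap.Mapnik)
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
ax.set_title('Vessel tracks over an OpenStreetMap basemap')
plt.show()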

    Visualization of vessel tracks with Contextily

    ⚠️ Basemap + Matplotlib

Note: mpl_toolkits.basemap requires NumPy v1; therefore, downgrade NumPy to v1.26.4 to use Basemap, or refer to the other alternatives mentioned, such as Contextily.

    Visualization of vessel tracks with Basemap

    Cartopy

    Visualization of vessel tracks with Cartopy

    Plotly

    Interactive visualization of vessel tracks with Plotly
    Interactive visualization of vessel positions with Plotly

    Kepler.gl

    Interactive visualization of vessel track positions with Kepler.gl
    Heat map of vessel track density with Kepler.gl


    Track Interpolation

    Track interpolation with AISdb involves generating estimated positions of vessels at specific intervals when actual AIS data points are unavailable. This process is important for filling in gaps in the vessel's trajectory, which can occur due to signal loss, data filtering, or other disruptions.

    In this tutorial, we introduce different types of track interpolation implemented in AISdb with usage examples.

    Example data preparation

First, we define functions to transform and visualize the track data (a generator object), with options to view either the data points or the tracks:

    We will use an actual track retrieved from the database for the examples in this tutorial and interpolate additional data points based on this track. The visualization will show the original track data points:

    Linear interpolation

    Linear interpolation estimates the vessel's position by drawing a straight line between two known points and calculating the positions at intermediate times. It is simple, fast, and straightforward but may not accurately represent complex movements.

    With equal time window intervals

    This method estimates the position of a vessel at regular time intervals (e.g., every 10 minutes). To perform linear interpolation with an equal time window on the track defined above:
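
A short example using aisdb.interp_time, which resamples each track at a fixed time step (here every 10 minutes); 'tracks' is assumed to be the TrackGen generator prepared above:

from datetime import timedelta
import aisdb

# Interpolate each track to one position every 10 minutes
tracks_interp = aisdb.interp_time(tracks, step=timedelta(minutes=10))

for track in tracks_interp:
    print(track['mmsi'], len(track['time']), 'positions after interpolation')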

    With equal distance intervals

    This method estimates the position of a vessel at regular spatial intervals (e.g., every 1 km along its path). To perform linear interpolation with equal distance intervals on the pseudo track defined above:

    Geodesic Track Interpolation

    This method estimates the positions of a vessel along a curved path using the principles of geometry, particularly involving great-circle routes.

    Cubic Spline Interpolation

    Given a set of data points, cubic spline interpolation fits a smooth curve through these points. The curve is represented as a series of cubic polynomials between each pair of data points. Each polynomial ensures a smooth curve at the data points (i.e., the first and second derivatives are continuous).

    Custom Track Interpolation

    In addition to the standard interpolation methods provided by AISdb, users can implement other interpolation techniques tailored to their specific analytical needs. For instance, B-spline (Basis Spline) interpolation is a mathematical technique that creates a smooth curve through data points. This smoothness is important in trajectory analysis as it avoids sharp, unrealistic turns and maintains a natural flow.

    Here is an implementation and example of using B-splines interpolation:
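
A sketch of such a helper built on SciPy; it assumes TrackGen-style dictionaries with 'lon', 'lat', and strictly increasing 'time' arrays, and resamples only those three fields (other per-position fields would need the same treatment):

import numpy as np
from scipy.interpolate import splprep, splev

def bspline_interpolate(tracks, n_points=200, smoothing=0.0):
    # Fit a cubic B-spline through each track's positions and resample it uniformly in time
    for track in tracks:
        lon = np.asarray(track['lon'], dtype=float)
        lat = np.asarray(track['lat'], dtype=float)
        time = np.asarray(track['time'], dtype=float)
        if lon.size < 4:
            # Not enough points for a cubic spline; yield the track unchanged
            yield track
            continue
        tck, _ = splprep([lon, lat], u=time, s=smoothing, k=3)
        new_time = np.linspace(time[0], time[-1], n_points)
        new_lon, new_lat = splev(new_time, tck)
        interped = dict(track)  # keep static fields such as 'mmsi'
        interped.update(lon=np.asarray(new_lon), lat=np.asarray(new_lat), time=new_time)
        yield interped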

    Then, we can apply the function just implemented on the vessel tracks generator:
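
For instance, reusing the generator, the function defined above, and the visualization call from the earlier examples:

tracks_bspline = bspline_interpolate(tracks, n_points=500, smoothing=0.0)
aisdb.web_interface.visualize(tracks_bspline, open_browser=True)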

The visualization of the interpolated track is shown below:

    Data Cleaning

    A common issue with AIS data is noise, where multiple vessels may broadcast using the same identifier simultaneously. AISdb incorporates data cleaning techniques to remove noise from vessel track data. For more details:

Denoising with Encoder: The aisdb.denoising_encoder.encode_greatcircledistance() function checks the approximate distance between each of a vessel's consecutive positions and splits the vector wherever the vessel could not reasonably have traveled along the most direct path, such as at speeds over 50 knots.

    Distance and Speed Thresholds: Distance and speed thresholds limit the maximum distance or time between messages that can be considered continuous.

    Scoring and Segment Concatenation: A score is computed for each position delta, with sequential messages nearby at shorter intervals given a higher score. This score is calculated by dividing the Haversine distance by elapsed time. Any deltas with a score not reaching the minimum threshold are considered the start of a new segment. New segments are compared to the end of existing segments with the same vessel identifier; if the score exceeds the minimum, they are concatenated. If multiple segments meet the minimum score, the new segment is concatenated to the existing segment with the highest score.

    Processing functions may be executed in sequence as a processing chain or pipeline, so after segmenting the individual voyages, results can be input into the encoder to remove noise and correct for vessels with duplicate identifiers effectively.

    After segmentation and encoding, the tracks are shown as:

For comparison, here is a snapshot of the tracks before cleaning:

    Coast, shore, and ports

    Extracting distance features from and to points-of-interest using raster files.

    The distances of a vessel from the nearest shore, coast, and port are essential to perform particular tasks such as vessel behavior analysis, environmental monitoring, and maritime safety assessments. AISdb offers functions to acquire these distances for specific vessel positions. In this tutorial, we provide examples of calculating the distance in kilometers from shore and from the nearest port for a given point.

    First, we create a sample track:
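
A hand-made track in the TrackGen dictionary format is enough for these examples; the coordinates and MMSI below are arbitrary, and the 'dynamic'/'static' key sets mirror the structure produced by TrackGen:

import numpy as np

sample_track = dict(
    mmsi=316000000,
    lon=np.array([-63.60, -63.55, -63.50, -63.45]),
    lat=np.array([44.60, 44.58, 44.55, 44.52]),
    time=np.array([1640995200, 1640995800, 1640996400, 1640997000]),  # epoch seconds
    dynamic={'lon', 'lat', 'time'},
    static={'mmsi'},
)
tracks = [sample_track]  # the distance utilities below expect an iterable of tracks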

    Here is what the sample track looks like:

    Sample track created for distance to shore and port calculation

    hashtag
    Distance from shore

    The ShoreDist class is used to calculate the nearest distance to shore, using a raster file containing shore-distance data. Currently, calling the get_distance function of ShoreDist automatically downloads the shore-distance raster file from our server. The function then merges the tracks in the provided track list, creates a new key, "km_from_shore", and stores the shore distance as the value for this key.

    hashtag
    Distance from coast

    Similar to acquiring the distance from shore, CoastDist is implemented to obtain the distance between the given track positions and the coastline.

    hashtag
    Distance from port

    Like the coast and shore distances, the PortDist class determines the distance between the track positions and the nearest port.

    Using Your AIS Data

    In addition to accessing data stored on the AISdb server, you can download open-source AIS data or import your datasets for processing and analysis using AISdb. This tutorial guides you through downloading AIS data from popular websites, creating SQLite and PostgreSQL databases compatible with AISdb, and establishing database connections. We provide two examples: Downloading and Processing Individual Files, which demonstrates working with small data samples and creating an SQLite database, and Pipeline for Bulk File Downloads and Database Integration, which outlines our approach to handling multiple data file downloads and creating a PostgreSQL database.

    hashtag
    Data Source

    U.S. vessel traffic data across user-defined geographies and periods is available at MarineCadastrearrow-up-right. This resource offers comprehensive AIS data for various maritime analysis purposes, and the dataset can be tailored to research needs by selecting specific regions and timeframes.

    hashtag
    Downloading and Processing Individual Files

    In the following example, we will show how to download and process a single data file and import the data to a newly created SQLite database.

    First, download one day of AIS data using the curl command:

    Then, extract the downloaded ZIP file to a specific path:

    Let's take a look at the columns in the downloaded CSV file:

    AISdb expects columns with specific names, which may differ from those in the imported dataset. Therefore, let's define the exact list of columns required:

    Next, we update the name of columns in the existing dataframe df_ and change the time format as required. The timestamp of an AIS message is represented by BaseDateTime in the default format YYYY-MM-DDTHH:MM:SS. For AISdb, however, the time is represented in UNIX format. We now read the CSV and apply the necessary changes to the date format:

    In the code above, the columns are renamed to the expected names and the data types of some columns are changed. In addition, an nm4 file usually contains raw messages in which a Message_ID separates static messages from dynamic ones; the MarineCadastre data has no such indicator. Static messages therefore need to be added explicitly so that a metadata table is created during database creation.

    Let's process the CSV to create an SQLite database using the aisdb package.

    An SQLite database has now been created.

    If you prefer to use a PostgreSQL database instead, define a PostgreSQL connection string and proceed with the database connection:

    hashtag
    Pipeline for Bulk File Downloads and Database Integration

    This section provides an example of downloading and processing multiple files, creating a PostgreSQL database, and loading data into tables. The steps are outlined in a series of pipeline scripts available in this GitHub repositoryarrow-up-right, which should be executed in the order indicated by their numbers.

    hashtag
    AIS Data Download and Extraction

    The first script, 0-download-ais.py, allows you to download AIS data from MarineCadastrearrow-up-right by specifying the years you need. If no years are specified, the script defaults to downloading data for 2023. The downloaded ZIP files are stored in a /data folder created in your current working directory. The second script, 1-zip2csv.py, extracts the CSV files from the downloaded ZIP files in /data and saves them in a new directory named /zip.

    To download and extract the data, simply run the two scripts in sequence:

    hashtag
    Preprocessing - Merge and Deduplication

    After downloading and extracting the AIS data, the 2-merge.py script consolidates the daily CSV files into monthly files, while the 3-deduplicate.py script removes duplicate rows so that only unique AIS messages remain. To run them, execute:

    The output of these two scripts is a set of cleaned CSV files stored in a new folder named /merged in your working directory.
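    For reference, the core of the merge-and-deduplicate step can be sketched with pandas as follows; the file-name pattern and folder names here are illustrative assumptions, and the actual scripts in the repository handle these details for you.

    # Minimal sketch of merging daily CSVs into a monthly file and dropping duplicates.
    # The glob pattern and output path below are assumptions for illustration.
    import glob
    import pandas as pd

    daily_files = sorted(glob.glob("./data/AIS_2023_01_*.csv"))
    monthly = pd.concat((pd.read_csv(f) for f in daily_files), ignore_index=True)

    # Keep only unique AIS messages
    monthly = monthly.drop_duplicates()
    monthly.to_csv("./merged/AIS_2023_01.csv", index=False)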

    hashtag
    PostgreSQL Database Creation and Data Loading to Tables

    The final script, 4-postgresql-database.py, creates a PostgreSQL database with a specified name. To do this, the script connects to a PostgreSQL server, so you must provide your username and password to establish the connection. After creating the database, the script verifies that the number of columns in each CSV file matches its header, creates a corresponding table in the database for each CSV file, and loads the data into it. To run this script, you need to provide three command-line arguments: -dbname for the new database name, -user for your PostgreSQL username, and -password for your PostgreSQL password. There are also two optional arguments, -host (default: localhost) and -port (default: 5432); adjust these if your PostgreSQL server is running on a different host or port.

    When the program reports that the task is finished, you can check the created database and loaded tables by connecting to the PostgreSQL server with the psql command-line interface:

    Once connected, you can list all tables in the database by running the \dt command. In our example using 2023 AIS data (default download), the tables will appear as follows:

    Vessel Metadata

    This tutorial demonstrates how to access vessel metadata using MMSI and SQLite databases. In many cases, AIS messages do not contain metadata. Therefore, this tutorial introduces the built-in functions in AISdb and external APIs to extract detailed vessel information associated with a specific MMSI from web sources.

    hashtag
    Metadata Download

    We introduce two methods implemented in AISdb for scraping metadata: using session requests for direct access, and employing web drivers with browsers to handle modern websites with dynamic content. Additionally, we provide an example of using a third-party API to access vessel information.

    Vessel Speed

    In AISdb, the speed of a vessel is calculated using the aisdb.gis.delta_knots function, which computes the speed over ground (SOG) in knots between consecutive positions within a given track. This calculation is important for the denoising encoder, as it compares the vessel's speed against a set threshold to aid in the data cleaning process.

    Vessel speed calculation requires the distance the vessel has traveled between two consecutive positions and the time interval between them. This distance is computed using the haversine distance function, and the time interval is simply the difference in timestamps between the two consecutive AIS position reports. The speed is then computed using the formula:

    The factor 1.9438445 converts the speed from meters per second to knots, the standard speed unit used in maritime contexts.
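    As a quick worked example using the first leg of the sample track (about 6961 m covered in one hour, taken from the haversine example below), the formula gives roughly 3.76 knots, which matches the speed values shown later on this page:

    # Worked example of the speed formula using numbers from the haversine example below
    distance_m = 6961.401286   # haversine distance between the first two positions, in meters
    elapsed_s = 3600           # one hour between the two position reports, in seconds

    speed_knots = distance_m / elapsed_s * 1.9438445
    print(round(speed_knots, 2))  # ~3.76 knots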

    Haversine Distance

    AISdb includes a function called aisdb.gis.delta_meters that calculates the Haversine distance in meters between consecutive positions within a vessel track. This function is essential for analyzing vessel movement patterns and ensuring accurate distance calculations on the Earth's curved surface. It is also integrated into the denoising encoder, which compares distances against a threshold to aid in the data-cleaning process.

    Here is an example of calculating the Haversine distance between each pair of consecutive points on a track:
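    The sketch below re-implements the haversine formula with NumPy for the first two positions of the sample track defined on this page; it illustrates the formula rather than aisdb's own code, and aisdb.gis.delta_meters applies the same calculation to entire tracks (exact values may differ slightly depending on the Earth radius used).

    import numpy as np

    def haversine_m(lon1, lat1, lon2, lat2, radius_m=6371000.0):
        """Great-circle (haversine) distance between two points, in meters."""
        lon1, lat1, lon2, lat2 = map(np.radians, (lon1, lat1, lon2, lat2))
        dlon, dlat = lon2 - lon1, lat2 - lat1
        a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
        return 2 * radius_m * np.arcsin(np.sqrt(a))

    # First two positions of the sample track used in these tutorials
    d = haversine_m(-63.52931373766157, 44.57039426840729,
                    -63.494075674952555, 44.51304767533133)
    print(d)  # roughly 6.96 km, consistent with the first value of the output below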

    If we visualize this track on the map, we can observe:

    AIS Data to CSV

    Building on the previous section, where we used AIS data to create AISdb databases, users can export AIS data from these databases into CSV format. In this section, we provide examples of exporting data from SQLite or PostgreSQL databases into CSV files. While we demonstrate these operations using internal data, you can apply the same techniques to your databases.

    hashtag
    Export CSV from SQLite Database

    In the first example, we connect to an SQLite database, query data within a specific time range and area of interest, and then export the queried data to a CSV file:
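    Below is a minimal sketch of this workflow; the database path, time range, and domain values are placeholders for illustration, and the exported columns are a small subset chosen for readability.

    import csv
    from datetime import datetime
    import aisdb
    from aisdb import DomainFromPoints

    dbpath = 'YOUR_DATABASE.db'                      # assumed database path
    start_time = datetime(2018, 1, 1)
    end_time = datetime(2018, 1, 2)
    domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000])

    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        tracks = aisdb.track_gen.TrackGen(qry.gen_qry(), decimate=False)

        # Write one CSV row per position report
        with open('exported_ais_data.csv', 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['mmsi', 'time', 'lon', 'lat', 'sog', 'cog'])
            for track in tracks:
                for i in range(len(track['time'])):
                    writer.writerow([track['mmsi'], track['time'][i], track['lon'][i],
                                     track['lat'][i], track['sog'][i], track['cog'][i]])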

    Now we can check the data in the exported CSV file:

    Decimation with AISdb

    Automatic Identification System (AIS) data provides a wealth of insight into maritime activities, including vessel movements and traffic patterns. However, the massive volume of AIS data, often consisting of millions or even billions of GPS position points, can be overwhelming. Processing and visualizing this raw data directly can be computationally expensive, slow, and difficult to interpret.

    This is where AISdb's decimation comes into play: it helps users efficiently reduce data clutter, making it easier to extract and focus on the most relevant information.
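    As a minimal sketch of what this looks like in practice, the decimate parameter of TrackGen (set to False throughout the other examples on this page) can be switched on when building tracks; the database path and time range below are placeholders for illustration.

    from datetime import datetime
    import aisdb

    dbpath = 'YOUR_DATABASE.db'          # assumed database path
    start_time = datetime(2018, 1, 1)
    end_time = datetime(2018, 1, 2)

    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
        )
        # decimate=True enables AISdb's built-in track decimation when generating tracks
        tracks = aisdb.track_gen.TrackGen(qry.gen_qry(), decimate=True)

        for track in tracks:
            print(track['mmsi'], len(track['time']), 'positions after decimation')
            break  # inspect just the first track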

    hashtag
    What is Decimation in the Context of AIS Tracks?

    ais_{YYYYMM}_static  # table with static AIS messages
    ais_{YYYYMM}_dynamic # table with dynamic AIS message
    static_{YYYYMM}_aggregate # table of aggregated static vessel data
    coarsetype_ref # a reference table that maps numeric ship type codes to their descriptions
    hashmap # internal AISdb bookkeeping table (e.g., file checksums)
    import sqlite3
    
    dbpath='YOUR_DATABASE.db' # Define the path to your database
    
    # Connect to the database
    connection = sqlite3.connect(dbpath)
    
    # Create a cursor object
    cursor = connection.cursor()
    
    # Query to list all tables
    query = "SELECT name FROM sqlite_master WHERE type='table';"
    cursor.execute(query)
    
    # Fetch the results
    tables = cursor.fetchall()
    
    # Print the names of the tables
    print("Tables in the database:")
    for table in tables:
        print(table[0])
    
    # Close the connection
    connection.close()
    import sqlite3
    
    # Connect to the database
    connection = sqlite3.connect('YOUR_DATABASE.db')
    
    # Create a cursor object
    cursor = connection.cursor()
    
    # Define the year and month suffix of the tables to query (e.g., ais_201801_dynamic)
    YYYYMM = '201801'
    
    # Define the JOIN SQL query
    query = f"""
    SELECT
        d.mmsi, 
        d.time,
        d.longitude,
        d.latitude,
        d.sog,
        d.cog,
        s.vessel_name,
        s.ship_type
    FROM ais_{YYYYMM}_dynamic d
    LEFT JOIN ais_{YYYYMM}_static s ON d.mmsi = s.mmsi
    WHERE d.time BETWEEN 1707033659 AND 1708176856  -- Filter by time range
      AND d.longitude BETWEEN -68 AND -56           -- Filter by geographical area
      AND d.latitude BETWEEN 45 AND 51.5;
    """
    
    # Execute the query
    cursor.execute(query)
    
    # Fetch the results
    results = cursor.fetchall()
    
    # Print the results
    for row in results:
        print(row)
    
    # Close the connection
    connection.close()
    import sqlite3
    from datetime import datetime
    
    from aisdb.gis import dt_2_epoch
    
    # Define the datetime range
    start_datetime = datetime(2018, 1, 1, 0, 0, 0)
    end_datetime = datetime(2018, 1, 1, 1, 59, 59)
    
    # Convert datetime to epoch time
    start_epoch = dt_2_epoch(start_datetime)
    end_epoch = dt_2_epoch(end_datetime)
    
    # Connect to the database
    connection = sqlite3.connect('YOUR_DATABASE.db')
    
    # Create a cursor object
    cursor = connection.cursor()
    
    # Define the JOIN SQL query using an epoch time range
    query = f"""
    SELECT
        d.mmsi, 
        d.time,
        d.longitude,
        d.latitude,
        d.sog,
        d.cog,
        s.vessel_name,
        s.ship_type
    FROM ais_201801_dynamic d
    LEFT JOIN ais_201801_static s ON d.mmsi = s.mmsi
    WHERE d.time BETWEEN {start_epoch} AND {end_epoch}  -- Filter by time range
      AND d.longitude BETWEEN -68 AND -56           -- Filter by geographical area
      AND d.latitude BETWEEN 45 AND 51.5;
    """
    
    # Execute the query
    cursor.execute(query)
    
    # Fetch the results
    results = cursor.fetchall()
    
    # Print the results
    for row in results:
        print(row)
    
    # Close the connection
    connection.close()
    Linux
    python -m venv AISdb         # create a python virtual environment
    source ./AISdb/bin/activate  # activate the virtual environment
    pip install aisdb            # from https://pypi.org/project/aisdb/
    Windows
    python -m venv AISdb         # create a virtual environment
    ./AISdb/Scripts/activate     # activate the virtual environment
    pip install aisdb            # install the AISdb package using pip
    $ python
    >>> import aisdb
    >>> aisdb.__version__        # should return '1.7.0' or newer
    # install nest-asyncio for enabling asyncio.run() in Jupyter Notebook
    %pip install nest-asyncio
    
    # Some of the systems may show the following error when running the user interface:
    # urllib3 v2.0 only supports OpenSSL 1.1.1+; currently, the 'SSL' module is compiled with 'LibreSSL 2.8.3'.
    # install urllib3 v1.26.6 to avoid this error
    %pip install urllib3==1.26.6
    from datetime import datetime, timedelta
    import os
    import aisdb
    import nest_asyncio
    nest_asyncio.apply()
    # List the test data files included in the package
    print(os.listdir(os.path.join(aisdb.sqlpath, '..', 'tests', 'testdata')))
    # You will see the print result: 
    # ['test_data_20210701.csv', 'test_data_20211101.nm4', 'test_data_20211101.nm4.gz']
    
    # Set the path for the SQLite database file to be used
    dbpath = './test_database.db'
    
    # Use test_data_20210701.csv as the test data
    filepaths = [os.path.join(aisdb.sqlpath, '..', 'tests', 'testdata', 'test_data_20210701.csv')]
    with aisdb.DBConn(dbpath = dbpath) as dbconn:
        aisdb.decode_msgs(filepaths=filepaths, dbconn=dbconn, source='TESTING')
    start_time = datetime.strptime("2021-07-01 00:00:00", '%Y-%m-%d %H:%M:%S')
    end_time = datetime.strptime("2021-07-02 00:00:00", '%Y-%m-%d %H:%M:%S')
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn,
            callback=aisdb.database.sql_query_strings.in_timerange,
            start=start_time,
            end=end_time,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    
        if __name__ == '__main__':
            aisdb.web_interface.visualize(
                tracks,
                visualearth=True,
                open_browser=True,
            )
    pip install psycopg2
    from aisdb.database.dbconn import PostgresDBConn
    
    # Option 1: Using keyword arguments
    dbconn = PostgresDBConn(
        hostaddr='127.0.0.1',      # Replace with the PostgreSQL address
        port=5432,                 # Replace with the PostgreSQL running port
        user='USERNAME',           # Replace with the PostgreSQL username
        password='PASSWORD',       # Replace with your password
        dbname='aisviz'            # Replace with your database name
    )
    
    # Option 2: Using a connection string
    dbconn = PostgresDBConn('postgresql://USERNAME:PASSWORD@HOST:PORT/DATABASE')
    import time
    import aisdb
    
    # Assumes `psql_conn_string` holds your PostgreSQL connection string (see above)
    start_year = 2024
    end_year = 2024
    start_month = 1
    end_month = 12
    
    overall_start_time = time.time()
    
    for year in range(start_year, end_year + 1):
        for month in range(start_month, end_month + 1):
            print(f'Loading {year}{month:02d}')
            month_start_time = time.time()
    
            filepaths = aisdb.glob_files(f'/slow-array/Spire/{year}{month:02d}/','.zip')
            filepaths = sorted([f for f in filepaths if f'{year}{month:02d}' in f])
            print(f'Number of files: {len(filepaths)}')
    
            with aisdb.PostgresDBConn(libpq_connstring=psql_conn_string) as dbconn:
                try:
                    aisdb.decode_msgs(filepaths,
                                    dbconn=dbconn,
                                    source='Spire',
                                    verbose=True,
                                    skip_checksum=True,
                                    raw_insertion=True,
                                    workers=6,
                                    timescaledb=True,
                            )
                except Exception as e:
                    print(f'Error loading {year}{month:02d}: {e}')
                    continue
    import aisdb
    from aisdb.gis import DomainFromPoints
    from aisdb.database.dbqry import DBQuery
    from datetime import datetime
    
    # Define a spatial domain centered around the point (-63.6, 44.6) with a radial distance of 50000 meters.
    domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000])
    
    # Create a query object to fetch AIS data within the specified time range and spatial domain.
    qry = DBQuery(
        dbconn=dbconn,
        start=datetime(2023, 1, 1), end=datetime(2023, 2, 1),
        xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
        ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
        callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi
    )
    
    # Generate rows from the query
    rowgen = qry.gen_qry()
    
    # Convert the generated rows into tracks
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    
    # Visualize the tracks on a map
    aisdb.web_interface.visualize(
        tracks,           # The tracks (trajectories) to visualize.
        domain=domain,    # The spatial domain to use for the visualization.
        visualearth=True, # If True, use Visual Earth for the map background.
        open_browser=True # If True, automatically open the visualization in a web browser.
    )
    from aisdb.database import sqlfcn
    from aisdb.database.dbqry import DBQuery
    
    # Specify database path
    dbpath = ...
    
    # Specify constraints (optional)
    start_time = ...
    end_time = ...
    domain = ...
    
    # Create a query object to fetch data within time and geographical range
    qry = DBQuery(
        dbconn=dbconn,                  # Database connection object
        start=start_time,               # Start time for the query
        end=end_time,                   # End time for the query
        xmin=domain.boundary['xmin'],   # Minimum longitude of the domain
        xmax=domain.boundary['xmax'],   # Maximum longitude of the domain
        ymin=domain.boundary['ymin'],   # Minimum latitude of the domain
        ymax=domain.boundary['ymax'],   # Maximum latitude of the domain
        callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi  # Callback function to filter data
    )
    # Generate rows from the query
    rowgen = qry.gen_qry(fcn=sqlfcn.crawl_dynamic_static) # callback parameter is optional
    
    # Process the generated rows as needed
    for row in rowgen:
        print(row)
    from aisdb.track_gen import TrackGen
    
    # Convert the generated rows into tracks
    tracks = TrackGen(rowgen, decimate=False)
    for track in tracks:
        mmsi = track['mmsi']
        lons = track['lon']
        lats = track['lat']
        speeds = track['sog']
        
        print(f"Track for vessel MMSI {mmsi}:")
        for lon, lat, speed in zip(lons[:3], lats[:3], speeds[:3]):
            print(f" - Lon: {lon}, Lat: {lat}, Speed: {speed}")
        break  # Exit after the first track
    Track for vessel MMSI 316004240:
     - Lon: -63.54868698120117, Lat: 44.61691665649414, Speed: 7.199999809265137
     - Lon: -63.54880905151367, Lat: 44.61708450317383, Speed: 7.099999904632568
     - Lon: -63.55659866333008, Lat: 44.626953125, Speed: 1.5
    import os
    import aisdb
    from datetime import datetime, timedelta
    from aisdb import DBConn, DBQuery, DomainFromPoints
    
    dbpath='YOUR_DATABASE.db' # Define the path to your database
    start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
    end_time = datetime.strptime("2018-01-02 00:00:00", '%Y-%m-%d %H:%M:%S')
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
        )
        rowgen = qry.gen_qry()
        
        # Convert queried rows to vessel trajectories
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        
        # Visualization
        aisdb.web_interface.visualize(
            tracks,
            domain=domain,
            visualearth=True,
            open_browser=True,
        )
    domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000]) # a circle with a 50 km radius around the location point
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_validmmsi_bbox,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        
        aisdb.web_interface.visualize(
            tracks,
            domain=domain,
            visualearth=True,
            open_browser=True,
        )
    callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi
    import random
    
    def assign_colors(mmsi_list):
        colors = {}
        for mmsi in mmsi_list:
            colors[mmsi] = "#{:06x}".format(random.randint(0, 0xFFFFFF))  # Random color in hex
        return colors
    
    # Create a function to color tracks
    def color_tracks(tracks, colors):
        colored_tracks = []
        for track in tracks:
            mmsi = track['mmsi']
            color = colors.get(mmsi, "#000000")  # Default to black if no color assigned
            track['color'] = color
            colored_tracks.append(track)
        return colored_tracks
    
    # Set the start and end times for the query
    start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
    end_time = datetime.strptime("2018-12-31 00:00:00", '%Y-%m-%d %H:%M:%S')
    
    # Create a list of vessel MMSIs you want to track 
    MMSI = [636017611,636018124,636018253]
    
    # Assign colors to each MMSI
    colors = assign_colors(MMSI)
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time, mmsis = MMSI,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_inmmsi,
        )
        rowgen = qry.gen_qry()
        
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        colored_tracks = color_tracks(tracks, colors)
    
        # Visualizing the tracks
        aisdb.web_interface.visualize(
            colored_tracks,
            visualearth=True,
            open_browser=True,
        )
    import aisdb
    from datetime import datetime
    from aisdb.database.dbconn import SQLiteDBConn
    from aisdb import DBConn, DBQuery, DomainFromPoints
    
    import nest_asyncio
    nest_asyncio.apply()
    
    dbpath='YOUR_DATABASE.db' # Define the path to your database
    
    # Set the start and end times for the query
    start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
    end_time = datetime.strptime("2018-01-03 00:00:00", '%Y-%m-%d %H:%M:%S')
    
    # Define a circle with a 100km radius around the location point
    domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[100000]) 
    
    def color_tracks(tracks):
        """ Set the color of each vessel track using a color name or RGB value. """
        for track in tracks:
            track['color'] = 'yellow'
            yield track
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        rowgen = qry.gen_qry()
        
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        colored_tracks = color_tracks(tracks)
    
        # Visualization
        aisdb.web_interface.visualize(
            colored_tracks,
            domain=domain,
            visualearth=True,
            open_browser=True,
        )
    import random
    
    def color_tracks2(tracks):
        colors = {}
        for track in tracks:
            mmsi = track.get('mmsi')
            if mmsi not in colors:
                # Assign a random color to this MMSI if not already assigned
                colors[mmsi] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
            track['color'] = colors[mmsi]  # Set the color for the current track
            yield track
    
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        rowgen = qry.gen_qry()
        
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        colored_tracks = list(color_tracks2(tracks))
    
        # Visualization
        aisdb.web_interface.visualize(
            colored_tracks,
            domain=domain,
            visualearth=True,
            open_browser=True,
        )
    import aisdb
    from datetime import datetime
    from aisdb.database.dbconn import SQLiteDBConn
    from aisdb import DBConn, DBQuery, DomainFromPoints
    import contextily as cx
    import matplotlib.pyplot as plt
    import random 
    import nest_asyncio
    nest_asyncio.apply()
    
    dbpath='YOUR_DATABASE.db' # Define the path to your database
    
    # Set the start and end times for the query
    start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
    end_time = datetime.strptime("2018-01-03 00:00:00", '%Y-%m-%d %H:%M:%S')
    
    # Define a circle with a 100km radius around the location point
    domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[100000])
     
    def color_tracks2(tracks):
        colors = {}
        for track in tracks:
            mmsi = track.get('mmsi')
            if mmsi not in colors:
                # Assign a random color to this MMSI if not already assigned
                colors[mmsi] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
            track['color'] = colors[mmsi]  # Set the color for the current track
            yield track
    
    def plot_tracks_with_contextily(tracks):
        plt.figure(figsize=(12, 8))
    
        for track in tracks:
            plt.plot(track['lon'], track['lat'], color=track['color'], linewidth=2)
    
        # Add basemap
        cx.add_basemap(plt.gca(), crs='EPSG:4326', source=cx.providers.CartoDB.Positron)
    
        plt.xlabel('Longitude')
        plt.ylabel('Latitude')
        plt.title('Vessel Tracks with Basemap')
        plt.show()
        
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        rowgen = qry.gen_qry()
        
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        colored_tracks = list(color_tracks2(tracks))
        
        plot_tracks_with_contextily(colored_tracks)
    from mpl_toolkits.basemap import Basemap
    import matplotlib.pyplot as plt
    
    def plot_tracks_with_basemap(tracks):
        plt.figure(figsize=(12, 8))
        # Define the geofence boundaries
        llcrnrlat = 42.854329883666175  # Latitude of the southwest corner
        urcrnrlat = 47.13666808816243   # Latitude of the northeast corner
        llcrnrlon = -68.73998377599209  # Longitude of the southwest corner
        urcrnrlon = -56.92378296577808  # Longitude of the northeast corner
    
        # Create the Basemap object with the geofence
        m = Basemap(projection='merc', 
                    llcrnrlat=llcrnrlat, urcrnrlat=urcrnrlat,
                    llcrnrlon=llcrnrlon, urcrnrlon=urcrnrlon, resolution='i')
        
        m.drawcoastlines()
        m.drawcountries()
        m.drawmapboundary(fill_color='aqua')
        m.fillcontinents(color='lightgreen', lake_color='aqua')
    
        for track in tracks:
            lons, lats = track['lon'], track['lat']
            x, y = m(lons, lats)
            m.plot(x, y, color=track['color'], linewidth=2)
    
        plt.show()
        
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        rowgen = qry.gen_qry()
        
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        colored_tracks = list(color_tracks2(tracks))
        
        plot_tracks_with_basemap(colored_tracks)
    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt
    
    def plot_tracks_with_cartopy(tracks):
        plt.figure(figsize=(12, 8))
        ax = plt.axes(projection=ccrs.Mercator())
        ax.coastlines()
        
        for track in tracks:
            lons, lats = track['lon'], track['lat']
            ax.plot(lons, lats, transform=ccrs.PlateCarree(), color=track['color'], linewidth=2)
        
        plt.title('AIS Tracks Visualization with Cartopy')
        plt.show()
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        rowgen = qry.gen_qry()
        
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        colored_tracks = list(color_tracks2(tracks))
        
        plot_tracks_with_cartopy(colored_tracks)
    import plotly.graph_objects as go
    import plotly.express as px
    import pandas as pd
    
    def track2dataframe(tracks):
        data = []
        # Iterate over each track in the vessels_generator
        for track in tracks:
            # Unpack static information
            mmsi = track['mmsi']
            rot = track['rot']
            maneuver = track['maneuver']
            heading = track['heading']
        
            # Unpack dynamic information
            times = track['time']
            lons = track['lon']
            lats = track['lat']
            cogs = track['cog']
            sogs = track['sog']
            utc_seconds = track['utc_second']
        
            # Iterate over the dynamic arrays and create a row for each time point
            for i in range(len(times)):
                data.append({
                    'mmsi': mmsi,
                    'rot': rot,
                    'maneuver': maneuver,
                    'heading': heading,
                    'time': times[i],
                    'longitude': lons[i],
                    'latitude': lats[i],
                    'cog': cogs[i],
                    'sog': sogs[i],
                    'utc_second': utc_seconds[i],
                })
                
        # Convert the list of dictionaries to a pandas DataFrame
        df = pd.DataFrame(data)
        
        return df
    
    def plotly_visualize(data, visual_type='lines'):
        if (visual_type=='scatter'):
            # Create a scatter plot for the vessel data points using scatter_geo
            fig = px.scatter_geo(
                data,
                lat="latitude",
                lon="longitude",
                color="mmsi",  # Color by vessel identifier
                hover_name="mmsi",
                hover_data={"time": True},
                title="Vessel Data Points"
            )
        else:
            # Create a line plot for the vessel trajectory using scatter_geo
            fig = px.line_geo(
                data,
                lat="latitude",
                lon="longitude",
                color="mmsi",  # Color by vessel identifier
                hover_name="mmsi",
                hover_data={"time": True},
            )
        
        # Set the map style and projection
        fig.update_geos(
            projection_type="azimuthal equal area",  # Change this to 'natural earth', 'azimuthal equal area', etc.
            showland=True,
            landcolor="rgb(243, 243, 243)",
            countrycolor="rgb(204, 204, 204)",
            lonaxis=dict(range=[-68.73998377599209, -56.92378296577808]),  # Longitude range (geofence)
            lataxis=dict(range=[42.854329883666175, 47.13666808816243])   # Latitude range (geofence)
        )
        
        # Set the layout to focus on a specific area or zoom level
        fig.update_layout(
            geo=dict(
                projection_type="mercator",
                center={"lat": 44.5, "lon": -63.5},
            ),
            width=900,  # Increase the width of the plot
            height=700,  # Increase the height of the plot
        ) 
        fig.show()
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    
        df = track2dataframe(tracks)
        plotly_visualize(df, 'lines')
    import pandas as pd
    from keplergl import KeplerGl
    
    def visualize_with_kepler(data, config=None):
        map_1 = KeplerGl(height=600)
        map_1.add_data(data=data, name="AIS Data")
        map_1.save_to_html(file_name='./figure/kepler_map.html')
        
        return map_1
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        
        df = track2dataframe(tracks)
        map_1 = visualize_with_kepler(df)
    import aisdb
    import numpy as np
    from datetime import datetime
    from aisdb.gis import dt_2_epoch
    
    y1, x1 = 44.57039426840729, -63.52931373766157
    y2, x2 = 44.51304767533133, -63.494075674952555
    y3, x3 = 44.458038982492134, -63.535634138077945
    y4, x4 = 44.393941339104074, -63.53826396955358
    y5, x5 = 44.14245580737021, -64.16608964280064
    
    t1 = dt_2_epoch( datetime(2021, 1, 1, 1) )
    t2 = dt_2_epoch( datetime(2021, 1, 1, 2) )
    t3 = dt_2_epoch( datetime(2021, 1, 1, 3) )
    t4 = dt_2_epoch( datetime(2021, 1, 1, 4) )
    t5 = dt_2_epoch( datetime(2021, 1, 1, 7) )
    
    # creating a sample track
    tracks_short = [
        dict(
            mmsi=123456789,
            lon=np.array([x1, x2, x3, x4, x5]),
            lat=np.array([y1, y2, y3, y4, y5]),
            time=np.array([t1, t2, t3, t4, t5]),
            dynamic=set(['lon', 'lat', 'time']),
            static=set(['mmsi'])
        )
    ]

    Column         Type      Description
    rot            REAL      Rate of turn, indicating how fast the vessel is turning.
    sog            REAL      Speed over ground, in knots.
    cog            REAL      Course over ground, in degrees.
    heading        REAL      Heading of the vessel, in degrees.
    maneuver       BOOLEAN   Indicator for whether the vessel is performing a special maneuver.
    utc_second     INTEGER   Second of the UTC timestamp when the message was generated.
    source         TEXT      Source of the AIS data.

    Column         Type      Description
    call_sign      TEXT      International radio call sign of the vessel.
    imo            INTEGER   International Maritime Organization number, another unique vessel identifier.
    dim_bow        INTEGER   Distance from the AIS transmitter to the bow (front) of the vessel.
    dim_stern      INTEGER   Distance from the AIS transmitter to the stern (back) of the vessel.
    dim_port       INTEGER   Distance from the AIS transmitter to the port (left) side of the vessel.
    dim_star       INTEGER   Distance from the AIS transmitter to the starboard (right) side of the vessel.
    draught        REAL      Maximum depth of the vessel's hull below the waterline, in meters.
    destination    TEXT      Destination port or location where the vessel is heading.
    ais_version    INTEGER   AIS protocol version used by the vessel.
    fixing_device  TEXT      Type of device used for fixing the vessel's position (e.g., GPS).
    eta_month      INTEGER   Estimated time of arrival month.
    eta_day        INTEGER   Estimated time of arrival day.
    eta_hour       INTEGER   Estimated time of arrival hour.
    eta_minute     INTEGER   Estimated time of arrival minute.

    Column         Type      Description
    source         TEXT      Source of the AIS data (e.g., specific AIS receiver or data provider).
    call_sign      TEXT      International radio call sign of the vessel.
    dim_bow        INTEGER   Distance from the AIS transmitter to the bow (front) of the vessel.
    dim_stern      INTEGER   Distance from the AIS transmitter to the stern (back) of the vessel.
    dim_port       INTEGER   Distance from the AIS transmitter to the port (left) side of the vessel.
    dim_star       INTEGER   Distance from the AIS transmitter to the starboard (right) side of the vessel.
    draught        REAL      Maximum depth of the vessel's hull below the waterline, in meters.
    destination    TEXT      Destination port or location where the vessel is heading.
    eta_month      INTEGER   Estimated time of arrival month.
    eta_day        INTEGER   Estimated time of arrival day.
    eta_hour       INTEGER   Estimated time of arrival hour.
    eta_minute     INTEGER   Estimated time of arrival minute.

    Original data points of the vessel track queried from database
    Linear interpolation on the vessel track with equal time intervals
    Linear interpolation with equal distance intervals
    Linear interpolation of the vessel track along the geodesic curve
    Cubic spline interpolation with equal time intervals
    B-spline interpolation with equal distance intervals of 1 km
    Queried vessel tracks after applying track segmentation and encoder (distance threshold=20km, speed threshold=50knots)
    Queried vessel tracks before cleaning
    [ 6961.401286 6948.59446128 7130.40147082 57279.94580704]
    hashtag
    Session Request

    The session request method in Python is a straightforward and efficient approach for retrieving metadata from websites. In AISdb, the aisdb.webdata._scraper.search_metadata_vesselfinder function leverages this method to scrape detailed information about vessels based on their MMSI numbers. This function efficiently gathers a range of data, including vessel name, type, flag, tonnage, and navigation status.

    This is an example of how to use the search_metadata_vesselfinder feature in AISdb to scrape data from the VesselFinderarrow-up-right website:
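    The sketch below shows how such a call might look; the exact signature of search_metadata_vesselfinder (here assumed to take a list of MMSIs) is an assumption for illustration, so consult the AISdb API reference for the authoritative interface.

    # Hedged sketch: the call signature below is assumed, not taken from the AISdb docs
    from aisdb.webdata._scraper import search_metadata_vesselfinder

    mmsis = [316004240]  # example MMSI that appears earlier in these tutorials
    metadata = search_metadata_vesselfinder(mmsis)
    print(metadata)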

    hashtag
    MarineTraffic API

    In addition to scraping metadata, we can also use the APIs that commercial data providers offer. MarineTraffic provides a subscription API for accessing vessel data, voyage forecasts, vessel positions, and more. Here is an example of retrieving vessel information:
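    As a hedged sketch of how a subscription REST API is typically queried with the requests library, the endpoint URL and parameter names below are placeholders rather than MarineTraffic's actual API specification; substitute the values from your MarineTraffic subscription documentation.

    import requests

    API_KEY = "YOUR_MARINETRAFFIC_API_KEY"                  # placeholder
    ENDPOINT = "https://api.example.com/vessel-positions"   # placeholder endpoint

    response = requests.get(ENDPOINT, params={"api_key": API_KEY, "mmsi": 316004240})
    response.raise_for_status()
    print(response.json())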

    hashtag
    Metadata Storage

    If you already have a database containing AIS track data, then vessel metadata can be downloaded and stored in a separate database.
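    As a minimal sketch of the storage step, the records returned by one of the methods above can be written to a separate SQLite database with pandas; the example record and its field names below are assumptions for illustration.

    import sqlite3
    import pandas as pd

    # `metadata` stands in for the records returned by one of the methods above;
    # the field names in this example row are illustrative assumptions.
    metadata = [{'mmsi': 316004240, 'vessel_name': 'EXAMPLE VESSEL', 'flag': 'CA'}]

    meta_df = pd.DataFrame(metadata)
    with sqlite3.connect('vessel_metadata.db') as conn:
        meta_df.to_sql('vessel_metadata', conn, if_exists='replace', index=False)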

    With the example track we created in Haversine Distance, we can calculate the vessel speed between each two consecutive positions:
    Speed(knot) = \frac{Haversine\ Distance}{Time} \times 1.9438445
    [3.7588560005768947 3.7519408684140214 3.8501088005116215 10.309565520121597]
    hashtag
    Export CSV from PostgreSQL Database

    Exporting from a PostgreSQL database works much the same way; the only difference is that you connect to your PostgreSQL database before querying the data you want to export to CSV. An example follows:
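    Below is a minimal sketch of such an export; the connection string, time range, and callback are placeholders in the style of the earlier PostgreSQL examples, and again only a small illustrative subset of columns is written.

    import csv
    from datetime import datetime
    import aisdb

    start_time = datetime(2023, 1, 1)
    end_time = datetime(2023, 2, 1)

    with aisdb.PostgresDBConn('postgresql://USERNAME:PASSWORD@HOST:PORT/DATABASE') as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
        )
        tracks = aisdb.track_gen.TrackGen(qry.gen_qry(), decimate=False)

        with open('exported_ais_postgres.csv', 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['mmsi', 'time', 'lon', 'lat', 'sog', 'cog'])
            for track in tracks:
                for i in range(len(track['time'])):
                    writer.writerow([track['mmsi'], track['time'][i], track['lon'][i],
                                     track['lat'][i], track['sog'][i], track['cog'][i]])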

    We can check the output CSV file now:

    mmsi 	time 	lon 	lat 	cog 	sog 	utc_second 	heading 	rot 	maneuver
    0 	219014000 	1514767484 	-63.537167 	44.635834 	322 	0.0 	44 	295.0 	0.0 	0
    1 	219014000 	1514814284 	-63.537167 	44.635834 	119 	0.0 	45 	295.0 	0.0 	0
    2 	219014000 	1514829783 	-63.537167 	44.635834 	143 	0.0 	15 	295.0 	0.0 	0
    3 	219014000 	1514829843 	-63.537167 	44.635834 	171 	0.0 	15 	295.0 	0.0 	0
    4 	219014000 	1514830042 	-63.537167 	44.635834 	3 	0.0 	35 	295.0 	0.0 	0
    mmsi 	time 	lon 	lat 	cog 	sog 	utc_second 	heading 	rot 	maneuver
    0 	210108000 	1672545711 	-63.645 	44.68833 	173 	0.0 	0 	0.0 	0.0 	False
    1 	210108000 	1672545892 	-63.645 	44.68833 	208 	0.0 	0 	0.0 	0.0 	False
    2 	210108000 	1672546071 	-63.645 	44.68833 	176 	0.0 	0 	0.0 	0.0 	False
    3 	210108000 	1672546250 	-63.645 	44.68833 	50 	0.0 	0 	0.0 	0.0 	False
    4 	210108000 	1672546251 	-63.645 	44.68833 	50 	0.0 	0 	0.0 	0.0 	False
    import plotly.graph_objects as go
    import plotly.express as px
    import pandas as pd
    
    def track2dataframe(tracks):
        data = []
        for track in tracks:
            times = track['time']
            mmsi = track['mmsi']
            lons = track['lon']
            lats = track['lat']
        
            # Iterate over the dynamic arrays and create a row for each time point
            for i in range(len(times)):
                data.append({
                    'mmsi': mmsi,
                    'time': times[i],
                    'longitude': lons[i],
                    'latitude': lats[i],
                })
                
        return pd.DataFrame(data)
    
    def plotly_visualize(data, visual_type='lines'):
        if (visual_type=='scatter'):
            # Create a scatter plot for the vessel data points using scatter_geo
            fig = px.scatter_geo(
                data,
                lat="latitude",
                lon="longitude",
                color="mmsi",  # Color by vessel identifier
                hover_name="mmsi",
                hover_data={"time": True},
                title="Vessel Data Points"
            )
        else:
            # Create a line plot for the vessel trajectory using line_geo
            fig = px.line_geo(
                data,
                lat="latitude",
                lon="longitude",
                color="mmsi",  # Color by vessel identifier
                hover_name="mmsi",
                hover_data={"time": True},
                title="Vessel Trajectory"
            )
        
        # Set the map style and projection
        fig.update_geos(
            projection_type="azimuthal equal area",  # Change this to 'natural earth', 'azimuthal equal area', etc.
            showland=True,
            landcolor="rgb(243, 243, 243)",
            countrycolor="rgb(204, 204, 204)",
        )
        
        # Set the layout to focus on a specific area or zoom level
        fig.update_layout(
            geo=dict(
                projection_type="azimuthal equal area",
            ),
            width=1200,  # Increase the width of the plot
            height=800,  # Increase the height of the plot
        )
        
        fig.show()
    import aisdb
    import numpy as np
    import nest_asyncio
    from aisdb import DBConn, DBQuery
    from datetime import timedelta, datetime
    
    nest_asyncio.apply()
    dbpath='YOUR_DATABASE.db' # Define the path to your database
    
    MMSI = 636017611 # MMSI of the vessel
    
    # Set the start and end times for the query
    start_time = datetime.strptime("2018-03-10 00:00:00", '%Y-%m-%d %H:%M:%S')
    end_time = datetime.strptime("2018-03-31 00:00:00", '%Y-%m-%d %H:%M:%S')
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        
        # Visualize the original track data points
        df = track2dataframe(tracks)
        plotly_visualize(df, 'scatter')
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    
        tracks__ = aisdb.interp.interp_time(tracks, timedelta(minutes=10))
    
        df = track2dataframe(tracks__)
        plotly_visualize(df)
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    
        tracks__ = aisdb.interp.interp_spacing(spacing=500, tracks=tracks)
    
        # Visualizing the tracks
        df = track2dataframe(tracks__)
        plotly_visualize(df)
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    
        tracks__ = aisdb.interp.geo_interp_time(tracks, timedelta(minutes=10))
    
        df = track2dataframe(tracks__)
        plotly_visualize(df)
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    
        tracks__ = aisdb.interp.interp_cubic_spline(tracks, timedelta(minutes=10))
    
        # Visualizing the tracks
        df = track2dataframe(tracks__)
        plotly_visualize(df)
    import warnings
    
    import numpy as np
    from scipy.interpolate import splrep, splev
    
    def bspline_interpolation(track, key, intervals):
        """
        Perform B-Spline interpolation for a specific key on the track data.
    
        Parameters:
        - track: Dictionary containing vessel track data (time, lat, lon, etc.).
        - key: The dynamic key (e.g., 'lat', 'lon') for which interpolation is performed.
        - intervals: The equal time or distance intervals at which interpolation is required.
    
        Returns:
        - Interpolated values for the specified key.
        """
        # Get time and the key values (e.g., lat/lon) for interpolation
        times = track['time']
        values = track[key]
    
        # Create the B-Spline representation of the curve
        tck = splrep(times, values, s=0)  # s=0 means no smoothing, exact fit to data
    
        # Interpolate the values at the given intervals
        interpolated_values = splev(intervals, tck)
    
        return interpolated_values
    
    def interp_bspline(tracks, step=1000):
        """
        Perform B-Spline interpolation on vessel trajectory data at equal time intervals.
    
        Parameters:
        - tracks: List of vessel track dictionaries.
        - step: Step for interpolation (can be time or distance-based).
    
        Yields:
        - Dictionary containing interpolated lat and lon values for each track.
        """
        for track in tracks:
            if len(track['time']) <= 1:
                warnings.warn('Cannot interpolate track of length 1, skipping...')
                continue
    
            # Generate equal time intervals based on the first and last time points
            intervals = np.arange(track['time'][0], track['time'][-1], step)
    
            # Perform B-Spline interpolation for lat and lon
            interpolated_lat = bspline_interpolation(track, 'lat', intervals)
            interpolated_lon = bspline_interpolation(track, 'lon', intervals)
    
            # Yield interpolated track
            itr = dict(
                mmsi=track['mmsi'],
                lat=interpolated_lat,
                lon=interpolated_lon,
                time=intervals  # Including interpolated time intervals for reference
            )
            yield itr
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time, mmsi = MMSI,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_hasmmsi,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    
        tracks__ = interp_bspline(tracks)
    
        # Visualizing the tracks
        df = track2dataframe(tracks__)
        plotly_visualize(df)
    import aisdb
    from datetime import datetime, timedelta
    from aisdb import DBConn, DBQuery, DomainFromPoints
    
    dbpath='YOUR_DATABASE.db' # Define the path to your database
    
    # Set the start and end times for the query
    start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
    end_time = datetime.strptime("2018-01-02 00:00:00", '%Y-%m-%d %H:%M:%S')
    
    # A circle with a 50 km radius around the location point
    domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000])
    
    maxdelta = timedelta(hours=24)  # the maximum time interval
    distance_threshold = 20000      # the maximum allowed distance (meters) between consecutive AIS messages
    speed_threshold = 50            # the maximum allowed vessel speed in consecutive AIS messages
    minscore = 1e-6                 # the minimum score threshold for track segment validation
    
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
        )
        rowgen = qry.gen_qry()
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        
        # Split the tracks into segments based on the maximum time interval
        track_segments = aisdb.split_timedelta(tracks, maxdelta)
        
        # Encode the track segments to clean and validate the track data
        tracks_encoded = aisdb.encode_greatcircledistance(track_segments, 
                                                          distance_threshold=distance_threshold, 
                                                          speed_threshold=speed_threshold, 
                                                          minscore=minscore)
        tracks_colored = color_tracks(tracks_encoded)
        
        aisdb.web_interface.visualize(
            tracks_colored,
            domain=domain,
            visualearth=True,
            open_browser=True,
        )
    from aisdb.webdata.shore_dist import ShoreDist
    
    with ShoreDist(data_dir="./testdata/") as sdist:
            # Getting distance from shore for each point in the track
            for track in sdist.get_distance(tracks_short):
                assert 'km_from_shore' in track['dynamic']
                assert 'km_from_shore' in track.keys()
                print(track['km_from_shore'])
    [ 1  3  2  9 14]
    from aisdb.webdata.shore_dist import CoastDist
    
    with CoastDist(data_dir="./testdata/") as cdist:
            # Getting distance from the coast for each point in the track
            for track in cdist.get_distance(tracks_short):
                assert 'km_from_coast' in track['dynamic']
                assert 'km_from_coast' in track.keys()
                print(track['km_from_coast'])
    [ 1  3  2  8 13]
    from aisdb.webdata.shore_dist import PortDist
    
    with PortDist(data_dir="./testdata/") as pdist:
            # Getting distance from the port for each point in the track
            for track in pdist.get_distance(tracks_short):
                assert 'km_from_port' in track['dynamic']
                assert 'km_from_port' in track.keys()
                print(track['km_from_port'])
    [ 4.72144175  7.47747231  4.60478449 11.5642271  28.62511253]
    curl -o ./data/AIS_2020_01_01.zip https://coast.noaa.gov/htdata/CMSP/AISDataHandler/2020/AIS_2020_01_01.zip
    unzip ./data/AIS_2020_01_01.zip -d ./data/
    import pandas as pd
    
    # Read CSV file in pandas dataframe
    df_ = pd.read_csv("./data/AIS_2020_01_01.csv", parse_dates=["BaseDateTime"])
    
    print(df_.columns)
    Index(['MMSI', 'BaseDateTime', 'LAT', 'LON', 'SOG', 'COG', 'Heading',
           'VesselName', 'IMO', 'CallSign', 'VesselType', 'Status',
           'Length', 'Width', 'Draft', 'Cargo', 'TransceiverClass'],
          dtype='object')
    list_of_headers_ = ["MMSI","Message_ID","Repeat_indicator","Time","Millisecond","Region","Country","Base_station","Online_data","Group_code","Sequence_ID","Channel","Data_length","Vessel_Name","Call_sign","IMO","Ship_Type","Dimension_to_Bow","Dimension_to_stern","Dimension_to_port","Dimension_to_starboard","Draught","Destination","AIS_version","Navigational_status","ROT","SOG","Accuracy","Longitude","Latitude","COG","Heading","Regional","Maneuver","RAIM_flag","Communication_flag","Communication_state","UTC_year","UTC_month","UTC_day","UTC_hour","UTC_minute","UTC_second","Fixing_device","Transmission_control","ETA_month","ETA_day","ETA_hour","ETA_minute","Sequence","Destination_ID","Retransmit_flag","Country_code","Functional_ID","Data","Destination_ID_1","Sequence_1","Destination_ID_2","Sequence_2","Destination_ID_3","Sequence_3","Destination_ID_4","Sequence_4","Altitude","Altitude_sensor","Data_terminal","Mode","Safety_text","Non-standard_bits","Name_extension","Name_extension_padding","Message_ID_1_1","Offset_1_1","Message_ID_1_2","Offset_1_2","Message_ID_2_1","Offset_2_1","Destination_ID_A","Offset_A","Increment_A","Destination_ID_B","offsetB","incrementB","data_msg_type","station_ID","Z_count","num_data_words","health","unit_flag","display","DSC","band","msg22","offset1","num_slots1","timeout1","Increment_1","Offset_2","Number_slots_2","Timeout_2","Increment_2","Offset_3","Number_slots_3","Timeout_3","Increment_3","Offset_4","Number_slots_4","Timeout_4","Increment_4","ATON_type","ATON_name","off_position","ATON_status","Virtual_ATON","Channel_A","Channel_B","Tx_Rx_mode","Power","Message_indicator","Channel_A_bandwidth","Channel_B_bandwidth","Transzone_size","Longitude_1","Latitude_1","Longitude_2","Latitude_2","Station_Type","Report_Interval","Quiet_Time","Part_Number","Vendor_ID","Mother_ship_MMSI","Destination_indicator","Binary_flag","GNSS_status","spare","spare2","spare3","spare4"]
    # Take the first 40,000 records from the original dataframe
    df = df_.iloc[0:40000]
    
    # Create a new dataframe with the specified headers
    df_new = pd.DataFrame(columns=list_of_headers_)
    
    # Populate the new dataframe with formatted data from the original dataframe
    df_new['Time'] = pd.to_datetime(df['BaseDateTime']).dt.strftime('%Y%m%d_%H%M%S')
    df_new['Latitude'] = df['LAT']
    df_new['Longitude'] = df['LON']
    df_new['Vessel_Name'] = df['VesselName']
    df_new['Call_sign'] = df['CallSign']
    df_new['Ship_Type'] = df['VesselType'].fillna(0).astype(int)
    df_new['Navigational_status'] = df['Status']
    df_new['Draught'] = df['Draft']
    df_new['Message_ID'] = 1  # Mark all messages as dynamic by default
    df_new['Millisecond'] = 0
    
    # Transfer additional columns from the original dataframe, if they exist
    for col_n in df_new:
        if col_n in df.columns:
            df_new[col_n] = df[col_n]
    
    # Extract static messages for each unique vessel
    filtered_df = df_new[df_new['Ship_Type'].notnull() & (df_new['Ship_Type'] != 0)]
    filtered_df = filtered_df.drop_duplicates(subset='MMSI', keep='first')
    filtered_df = filtered_df.reset_index(drop=True)
    filtered_df['Message_ID'] = 5  # Mark these as static messages
    
    # Merge dynamic and static messages into a single dataframe
    df_new = pd.concat([filtered_df, df_new])
    
    # Save the final dataframe to a CSV file
    # The quoting parameter is necessary because the csvreader reads each column value as a string by default
    df_new.to_csv("./data/AIS_2020_01_01_aisdb.csv", index=False, quoting=1)
    import aisdb
    
    # Establish a connection to the SQLite database and decode messages from the CSV file
    with aisdb.SQLiteDBConn('./data/test_decode_msgs.db') as dbconn:
            aisdb.decode_msgs(filepaths=["./data/AIS_2020_01_01_aisdb.csv"],
                              dbconn=dbconn, source='Testing', verbose=True)
    generating file checksums...
    checking file dates...
    creating tables and dropping table indexes...
    Memory: 20.65GB remaining.  CPUs: 12.  Average file size: 49.12MB  Spawning 4 workers
    saving checksums...
    processing ./data/AIS_2020_01_01_aisdb.csv
    AIS_2020_01_01_aisdb.csv                                         count:   49323    elapsed:    0.27s    rate:   183129 msgs/s
    cleaning temporary data...
    aggregating static reports into static_202001_aggregate...
    sqlite3 ./data/test_decode_msgs.db
    
    sqlite> .tables
    ais_202001_dynamic       coarsetype_ref           static_202001_aggregate
    ais_202001_static        hashmap 
    python 0-download-ais.py
    python 1-zip2csv.py
    python 2-merge.py
    python 3-deduplicate.py
    python 4-postgresql-database.py -dbname DBNAME -user USERNAME -password PASSWORD [-host HOST] [-port PORT]
    psql -U USERNAME -d DBNAME -h localhost -p 5432
    ais_pgdb=# \dt
               List of relations
     Schema |    Name     | Type  |  Owner  
    --------+-------------+-------+----------
     public | ais_2023_01 | table | postgres
     public | ais_2023_02 | table | postgres
     public | ais_2023_03 | table | postgres
     public | ais_2023_04 | table | postgres
     public | ais_2023_05 | table | postgres
     public | ais_2023_06 | table | postgres
     public | ais_2023_07 | table | postgres
     public | ais_2023_08 | table | postgres
     public | ais_2023_09 | table | postgres
     public | ais_2023_10 | table | postgres
     public | ais_2023_11 | table | postgres
     public | ais_2023_12 | table | postgres
    (12 rows) 
    import aisdb
    import numpy as np
    from aisdb.gis import dt_2_epoch
    from datetime import datetime
    
    y1, x1 = 44.57039426840729, -63.52931373766157
    y2, x2 = 44.51304767533133, -63.494075674952555
    y3, x3 = 44.458038982492134, -63.535634138077945
    y4, x4 = 44.393941339104074, -63.53826396955358
    y5, x5 = 44.14245580737021, -64.16608964280064
    
    t1 = dt_2_epoch( datetime(2021, 1, 1, 1) )
    t2 = dt_2_epoch( datetime(2021, 1, 1, 2) )
    t3 = dt_2_epoch( datetime(2021, 1, 1, 3) )
    t4 = dt_2_epoch( datetime(2021, 1, 1, 4) )
    t5 = dt_2_epoch( datetime(2021, 1, 1, 7) )
    
    # Create a sample track
    tracks_short = [
        dict(
            lon=np.array([x1, x2, x3, x4, x5]),
            lat=np.array([y1, y2, y3, y4, y5]),
            time=np.array([t1, t2, t3, t4, t5]),
            mmsi=123456789,
            dynamic=set(['lon', 'lat', 'time']),
            static=set(['mmsi'])
        )
    ]
    
    # Calculate the Haversine distance
    for track in tracks_short:
        print(aisdb.gis.delta_meters(track))
    from aisdb.webdata._scraper import search_metadata_vesselfinder
    
    MMSI = 228386800
    dict_ = search_metadata_vesselfinder(MMSI)
    
    print(dict_)
    {'IMO number': '9839131',
     'Vessel Name': 'CMA CGM CHAMPS ELYSEES',
     'Ship type': 'Container Ship',
     'Flag': 'France',
     'Homeport': '-',
     'Gross Tonnage': '236583',
     'Summer Deadweight (t)': '220766',
     'Length Overall (m)': '400',
     'Beam (m)': '61',
     'Draught (m)': '',
     'Year of Build': '2020',
     'Builder': '',
     'Place of Build': '',
     'Yard': '',
     'TEU': '',
     'Crude Oil (bbl)': '-',
     'Gas (m3)': '-',
     'Grain': '-',
     'Bale': '-',
     'Classification Society': '',
     'Registered Owner': '',
     'Owner Address': '',
     'Owner Website': '-',
     'Owner Email': '-',
     'Manager': '',
     'Manager Address': '',
     'Manager Website': '',
     'Manager Email': '',
     'Predicted ETA': '',
     'Distance / Time': '',
     'Course / Speed': '\xa0',
     'Current draught': '16.0 m',
     'Navigation Status': '\nUnder way\n',
     'Position received': '\n22 mins ago \n\n\n',
     'IMO / MMSI': '9839131 / 228386800',
     'Callsign': 'FLZF',
     'Length / Beam': '399 / 61 m'}
    import requests
    
    # Your MarineTraffic API key
    api_key = 'your_marine_traffic_api_key'
    
    # List of MMSI numbers you want to query
    mmsi = [228386800,
            372351000,
            373416000,
            477003800,
            477282400
    ]
    
    # Base URL for the MarineTraffic API endpoint
    url = f'https://services.marinetraffic.com/api/exportvessels/{api_key}'
    
    # Prepare the API request
    params = {
        'shipid': ','.join(str(m) for m in mmsi),  # Join the MMSI numbers with commas
        'protocol': 'jsono',            # Specify the response format
        'msgtype': 'extended'           # Specify the level of details
    }
    
    # Make the API request
    response = requests.get(url, params=params)
    
    # Check if the request was successful
    if response.status_code == 200:
        vessel_data = response.json()
        
        for vessel in vessel_data:
            print(f"Vessel Name: {vessel.get('NAME')}")
            print(f"MMSI: {vessel.get('MMSI')}")
            print(f"IMO: {vessel.get('IMO')}")
            print(f"Call Sign: {vessel.get('CALLSIGN')}")
            print(f"Type: {vessel.get('TYPE_NAME')}")
            print(f"Flag: {vessel.get('COUNTRY')}")
            print(f"Length: {vessel.get('LENGTH')}")
            print(f"Breadth: {vessel.get('BREADTH')}")
            print(f"Year Built: {vessel.get('YEAR_BUILT')}")
            print(f"Status: {vessel.get('STATUS_NAME')}")
            print('-' * 40)
    else:
        print(f"Failed to retrieve data: {response.status_code}")
    from aisdb import track_gen, DBConn, DBQuery, sqlfcn_callbacks
    from datetime import datetime

    dbpath = "/home/data_sample_dynamic.csv.db"

    with DBConn(dbpath=dbpath) as dbconn:
        qry = DBQuery(
            dbconn=dbconn, callback=sqlfcn_callbacks.in_timerange,
            start=datetime(2020, 1, 1, hour=0),
            end=datetime(2020, 12, 3, hour=2)
        )
        # A new database will be created if it does not exist to save the info downloaded from MarineTraffic
        traffic_database_path = "/home/traffic_info.db"

        # Users can select a custom boundary for the query using aisdb.Domain
        qry.check_marinetraffic(trafficDBpath=traffic_database_path,
                                boundary={"xmin": -180, "xmax": 180, "ymin": -90, "ymax": 90})

        rowgen = qry.gen_qry(verbose=True)
        trackgen = track_gen.TrackGen(rowgen, decimate=True)
    import aisdb
    import numpy as np
    from datetime import datetime
    from aisdb.gis import dt_2_epoch
    
    # Generate example track
    y1, x1 = 44.57039426840729, -63.52931373766157
    y2, x2 = 44.51304767533133, -63.494075674952555
    y3, x3 = 44.458038982492134, -63.535634138077945
    y4, x4 = 44.393941339104074, -63.53826396955358
    y5, x5 = 44.14245580737021, -64.16608964280064
    
    t1 = dt_2_epoch( datetime(2021, 1, 1, 1) )
    t2 = dt_2_epoch( datetime(2021, 1, 1, 2) )
    t3 = dt_2_epoch( datetime(2021, 1, 1, 3) )
    t4 = dt_2_epoch( datetime(2021, 1, 1, 4) )
    t5 = dt_2_epoch( datetime(2021, 1, 1, 7) )
    
    # Create a sample track
    tracks_short = [
        dict(
            mmsi=123456789,
            lon=np.array([x1, x2, x3, x4, x5]),
            lat=np.array([y1, y2, y3, y4, y5]),
            time=np.array([t1, t2, t3, t4, t5]),
            dynamic=set(['lon', 'lat', 'time']),
            static=set(['mmsi'])
        )
    ]
    
    # Calculate the vessel speed in knots
    for track in tracks_short:
        print(aisdb.gis.delta_knots(track))
    import csv
    import aisdb
    import nest_asyncio
    
    from aisdb import DBConn, DBQuery, DomainFromPoints
    from aisdb.database.dbconn import SQLiteDBConn
    from datetime import datetime
    
    nest_asyncio.apply()
    
    dbpath = 'YOUR_DATABASE.db' # Path to your database
    end_time = datetime.strptime("2018-01-02 00:00:00", '%Y-%m-%d %H:%M:%S')
    start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
    domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000])
    
    # Connect to the SQLite database and run the query
    with SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        tracks = list(aisdb.track_gen.TrackGen(qry.gen_qry(), decimate=False))  # materialize tracks while the connection is open
    
    # Define the headers for the CSV file
    headers = ['mmsi', 'time', 'lon', 'lat', 'cog', 'sog',
               'utc_second', 'heading', 'rot', 'maneuver']
    
    # Open the CSV file for writing
    csv_filename = 'output_sqlite.csv'
    with open(csv_filename, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=headers)
        writer.writeheader()  # Write the header once
        
        for track in tracks:
            for i in range(len(track['time'])):
                row = {
                    'rot': track['rot'],
                    'mmsi': track['mmsi'],
                    'lon': track['lon'][i],
                    'lat': track['lat'][i],
                    'cog': track['cog'][i],
                    'sog': track['sog'][i],
                    'time': track['time'][i],
                    'heading': track['heading'],
                    'maneuver': track['maneuver'],
                    'utc_second': track['utc_second'][i],
                }
                writer.writerow(row)  # Write the row to the CSV file
    
    print(f"All tracks have been combined and written to {csv_filename}")
    import csv
    import aisdb
    import nest_asyncio
    
    from datetime import datetime
    from aisdb.database.dbconn import PostgresDBConn
    from aisdb import DBConn, DBQuery, DomainFromPoints
    
    nest_asyncio.apply()
    
    dbconn = PostgresDBConn(
        host='localhost',          # PostgreSQL address
        port=5432,                 # PostgreSQL port
        user='your_username',      # PostgreSQL username
        password='your_password',  # PostgreSQL password
        dbname='database_name'     # Database name
    )
    
    # Define the query region (values here mirror the SQLite example above; adjust as needed)
    domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[50000])

    qry = DBQuery(
        dbconn=dbconn,
        start=datetime(2023, 1, 1), end=datetime(2023, 1, 3),
        xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
        ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
        callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi
    )
    
    tracks = aisdb.track_gen.TrackGen(qry.gen_qry(), decimate=False)
    
    # Define the headers for the CSV file
    headers = ['mmsi', 'time', 'lon', 'lat', 'cog', 'sog',
               'utc_second', 'heading', 'rot', 'maneuver']
    
    # Open the CSV file for writing
    csv_filename = 'output_postgresql.csv'
    with open(csv_filename, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=headers)
        writer.writeheader()  # Write the header once
        
        for track in tracks:
            for i in range(len(track['time'])):
                row = {
                    'rot': track['rot'],
                    'mmsi': track['mmsi'],
                    'lon': track['lon'][i],
                    'lat': track['lat'][i],
                    'cog': track['cog'][i],
                    'sog': track['sog'][i],
                    'time': track['time'][i],
                    'heading': track['heading'],
                    'maneuver': track['maneuver'],
                    'utc_second': track['utc_second'][i],
                }
                writer.writerow(row)  # Write the row to the CSV file
    
    print(f"All tracks have been combined and written to {csv_filename}")
    Decimation, in simple terms, means reducing the number of data points. When applied to AIS tracks, it involves selectively removing GPS points from a vessel’s trajectory while preserving its overall shape and key characteristics. Rather than processing every recorded position, decimation algorithms identify and retain the most relevant points, optimizing data efficiency without significant loss of accuracy.

    Think of it like simplifying a drawing: instead of using thousands of tiny dots to represent a complex image, you can use fewer, strategically chosen points to capture its essence. Similarly, decimation represents a vessel’s path with fewer points while maintaining its core trajectory, making analysis and visualization more efficient.

    hashtag
    Why Decimate AIS Data?

    There are several key benefits to using decimation techniques when working with AIS data:

    1. Improved Performance and Efficiency: Reducing the number of data points can dramatically decrease the computational load, enabling faster analyses, quicker visualizations, and more effective workflow, especially when dealing with large datasets.

    2. Clearer Visualizations: Dense tracks can clutter visualizations and make it difficult to interpret the data. Decimation simplifies the tracks, emphasizing significant movements and patterns for more intuitive analysis.

    3. Noise Reduction: While decimation is not designed as a noise removal technique, it can help smooth out minor inaccuracies and high-frequency fluctuations from raw GPS data. This can be useful for focusing on broader trends and vessel movements.

    hashtag
    AISdb and simplify_linestring_idx()

    In AISdb, the TrackGen() method includes a decimate parameter that, when set to True, triggers the simplify_linestring_idx(x, y, precision) function. This function uses the Visvalingam-Whyatt algorithmarrow-up-right to simplify vessel tracks while preserving key trajectory details.

    hashtag
    How the Visvalingam-Whyatt Algorithm Works

    The Visvalingam-Whyatt algorithm simplifies a line by removing the points that contribute the least to its overall shape. Here’s how it works:

    • The algorithm measures the importance of a point by calculating the area of the triangle formed by that point and its adjacent points.

    • Points on relatively straight segments form smaller triangles, meaning they’re less important in defining the shape.

    • Points at curves and corners form larger triangles, signaling that they’re crucial for maintaining the line’s characteristic form.

    The algorithm iteratively removes the points with the smallest triangle areas until the desired level of simplification is achieved. In AISdb, this process is controlled by the decimate parameter in the TrackGen() method.
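
    For intuition, here is a small, self-contained NumPy sketch of the idea; it is not AISdb’s internal implementation (that is exposed through simplify_linestring_idx, described below), and the sample coordinates are made up. It repeatedly drops the interior point whose triangle with its neighbours has the smallest area until a point budget is reached.

    import numpy as np

    def triangle_area(p1, p2, p3):
        """Area of the triangle formed by three (x, y) points."""
        return 0.5 * abs((p2[0] - p1[0]) * (p3[1] - p1[1]) - (p3[0] - p1[0]) * (p2[1] - p1[1]))

    def visvalingam_whyatt(x, y, n_keep):
        """Reduce a polyline to n_keep points by repeatedly removing the least important point."""
        points = list(zip(x, y))
        while len(points) > max(n_keep, 2):
            # Effective area of every interior point
            areas = [triangle_area(points[i - 1], points[i], points[i + 1])
                     for i in range(1, len(points) - 1)]
            # Remove the interior point with the smallest area (least contribution to the shape)
            points.pop(int(np.argmin(areas)) + 1)
        return np.array(points)

    # Example with made-up coordinates: keep 3 of 5 points
    simplified = visvalingam_whyatt(np.array([0, 1, 2, 3, 4]), np.array([0, 0.1, 0, 0.8, 0]), n_keep=3)
    print(simplified)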

    hashtag
    Using TrackGen(..., decimate=True) with AISDB Tracks

    Below is a conceptual Python example that demonstrates how to apply decimation to AIS tracks:

    hashtag
    Using simplify_linestring_idx() with AISDB Tracks

    To get more control over the precision of the decimation, use the simplify_linestring_idx function in AISdb.

    hashtag
    Illustration of Decimation

    (Amigo et al., 2021)

    hashtag
    Key Parameters and Usage Notes:

    • Precision: The precision parameter controls the level of simplification. A smaller value (e.g., 0.001) results in more retained points and higher fidelity, while a larger value (e.g., 0.1) simplifies the track further with fewer points.

    • x, y: These are NumPy arrays representing the longitude and latitude coordinates of the track points.

    • TrackGen Integration: Decimation is applied after generating tracks with aisdb.TrackGen, followed by the application of simplify_linestring_idx() to each track individually.

    • Iterative Refinement: Decimation is often an iterative process. You may need to visualize the decimated tracks, assess the level of simplification, and adjust the precision to balance simplification with data fidelity.

    hashtag
    Conclusion

    Decimation is a powerful tool for simplifying and decluttering AIS data. By intelligently reducing the data’s complexity, AISDB’s simplify_linestring_idx() and TrackGen() allow you to process data more efficiently, create clearer visualizations, and gain deeper insights from your maritime data. Experiment with different precision values, and discover how “less” data can lead to “more” meaningful results in your AIS analysis workflows!

    hashtag
    References

    1. Amigo D, Sánchez Pedroche D, García J, Molina JM. Review and classification of trajectory summarisation algorithms: From compression to segmentation. International Journal of Distributed Sensor Networks. 2021;17(10). doi:10.1177/15501477211050729arrow-up-right

    Bathymetric Data

    hashtag
    Process AIS data with Bathymetric Data

    This section demonstrates integrating AIS data with external bathymetric data to enrich our analysis. In the following example, we identified all vessels within a 500-kilometer radius around the central area of Halifax, Canada, on January 1, 2018.

    hashtag
    Raster file preparation

    First, we imported the necessary packages and prepared the bathymetry data. It’s important to note that the downloaded bathymetric data is divided into eight segments, organized by latitude and longitude. In a later step, you will need to select the appropriate bathymetric raster file based on the geographical region covered by your vessel track data.

    hashtag
    Coloring the tracks

    We defined a coloring criterion to classify tracks based on their average depths relative to the bathymetry. Tracks that traverse shallow waters with an average depth of less than 100 meters are colored in yellow. Those spanning depths between 100 and 1,000 meters are represented in orange, indicating a transition to deeper waters. Tracks with average depths between 1,000 and 20,000 meters are marked pink, and the deepest tracks, beyond 20,000 meters, are distinctly colored in red.

    hashtag
    Integration with the bathymetric raster file

    Next, we query the AIS data to be integrated with the bathymetric raster file and apply the coloring function to mark the tracks based on their average depths relative to the bathymetry.
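
    Below is a hedged sketch of this step. It reuses dbpath, start_time, end_time, the Gebco object (bathy), and the add_color() generator prepared in this section's code; the query region is illustrative, and the assumption that merge_tracks() attaches the depth values consumed by add_color() should be checked against your AISdb version (the key name may differ).

    # Sketch only -- reuses objects defined in this section's code blocks
    domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[500000])  # ~500 km around Halifax

    with SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        tracks = aisdb.track_gen.TrackGen(qry.gen_qry(), decimate=False)
        tracks = bathy.merge_tracks(tracks)   # assumed to attach bathymetric depth to each track
        tracks = add_color(tracks)            # color-code by average depth, as defined above

        for track in tracks:
            print(track['mmsi'], track['color'])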

    The integrated results are color-coded and can be visualized as shown below:

    Example of using bathymetry data to color-code vessel tracks based on their average depth:

    • Yellow: Tracks with an average depth of less than 100 meters (shallow waters).

    • Orange: Tracks with an average depth between 100 and 1,000 meters (transition to deeper waters).

    • Pink: Tracks with an average depth between 1,000 and 20,000 meters (deeper waters).

    • Red: Tracks with an average depth greater than 20,000 meters (deepest waters).

    Vessel tracks colored with average depths relative to the bathymetry

    Hexagon Discretization

    On this page, we will see how we can use AISdb to discretize AIS tracks into hexagons.

    hashtag
    Introduction

    Hexagonal Geospatial Indexing (H3), Uber’s hierarchical hexagonal geospatial indexing system, partitions the Earth into a multi-resolution hexagonal grid. Its key advantage over square grids is the “one-distance rule,” where all neighbors of a hexagon lie at comparable step distances.

    As illustrated in the figure above, this uniformity removes the diagonal-versus-edge ambiguity present in square lattices. For maritime work, hexagons are great because they reduce directional bias and make neighborhood queries and aggregation intuitive.

    Note: H3 indexes are 64-bit IDs typically shown as hex strings like “860e4d31fffffff.”

    hashtag
    Discretize AIS Lat/Lon points to hexagons using AISDb

    The code below provides a complete example of how to connect to a database of AIS data using AISDb and generate the corresponding H3 index for each data point.
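
    As a minimal sketch of the idea (the database path, time range, bounding box, and H3 resolution below are placeholders, and the h3 Python package is assumed to be installed), each AIS position can be mapped to an H3 cell like this:

    import h3      # pip install h3 (v4 API; in v3 the call is h3.geo_to_h3)
    import aisdb
    from datetime import datetime

    resolution = 6  # pick the H3 resolution that suits your analysis

    with aisdb.SQLiteDBConn('your_ais_database.db') as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn,
            start=datetime(2023, 1, 1), end=datetime(2023, 1, 2),
            xmin=-64, xmax=-62, ymin=44, ymax=45,
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        for track in aisdb.TrackGen(qry.gen_qry(), decimate=False):
            # One H3 index per position report in the track
            cells = [h3.latlng_to_cell(lat, lon, resolution)
                     for lat, lon in zip(track['lat'], track['lon'])]
            print(track['mmsi'], cells[:5])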

    Refer to the example notebook here: https://github.com/AISViz/AISdb/blob/master/examples/discretize.ipynbarrow-up-right

    hashtag
    References

    https://www.uber.com/en-CA/blog/h3/arrow-up-right

    Automatic Identification System

    The Automatic Identification System (AIS) is a standardized and unencrypted self-reporting maritime surveillance system.

    hashtag
    How does this work?

    The protocol operates by transmitting one or more of 27 message types from an AIS transponder onboard a vessel at fixed time intervals. These intervals depend on the vessel’s status—stationary vessels (anchored or moored) transmit every 3 minutes, while fast-moving vessels transmit every 2 seconds.

    These VHF radio messages are sent from the vessel’s transponder and received by either satellite or ground-based stations, enabling more detailed monitoring and analysis of maritime traffic.

    hashtag
    Types of AIS messages

    1. Dynamic messages convey the vessel's real-time status, which can vary between transmissions. These include data such as Speed Over Ground (SOG), Course Over Ground (COG), Rate of Turn (ROT), and the vessel’s current position (latitude and longitude).

    2. Static messages, on the other hand, provide information that remains constant over time. This includes details like the Maritime Mobile Service Identity (MMSI), International Maritime Organization (IMO) number, vessel name, call sign, type, dimensions, and intended destination.

    hashtag
    Limitations of AIS signals

    1. Signals from vessels can be lost, for example when messages from many transponders collide in busy areas or when a transponder is switched off.

    2. Terrestrial base stations are limited by their physical range, while satellite AIS receivers are limited based on their position globally.

    To learn more about AIS, refer to: https://www.marinelink.com/news/definitive-guide-ais418266arrow-up-right

    hashtag
    References

    1. Brousseau, M. (2022). A comprehensive analysis and novel methods for on-purpose AIS switch-off detection [Master’s thesis, Dalhousie University]. DalSpace. http://hdl.handle.net/10222/81160arrow-up-right

    2. Kazim, T. (2016, November 14). A definitive guide to AIS. MarineLink. Retrieved May 14, 2025, from https://www.marinelink.com/news/definitive-guide-ais418266arrow-up-right

    Weather Data

    This tutorial introduces integrating weather data from GRIB files with AIS data for enhanced vessel tracking analysis. Practical examples are provided below illustrating how to integrate AISdb tracks with the weather data in GRIB files.

    To directly work with the jupyter notebook, click here: https://github.com/AISViz/AISdb/blob/master/examples/weather.ipynbarrow-up-right

    hashtag
    Prerequisites

    Before diving in, users are expected to have the following set-up: a Copernicus CDS account (free) to access ERA5 data, which can be obtained through the ECMWF-Signuparrow-up-right, and AISdb set up either locally or remotely. Refer to the AISDB-Installationarrow-up-right documentation for detailed instructions and configuration options.

    hashtag
    Usage:

    Once you have a Copernicus CDS account and AISdb installed, you can download weather data in GRIB format directly from the CDS and use AISdb to extract specific variables from those files.

    AISdb supports both zipped (.zip) and uncompressed GRIB (.grib) files. These files should be named using the yyyy-mm format (e.g., 2023-03.grib or 2023-03.zip) and placed in a folder such as /home/CanadaV2.

    With the WeatherDataStore class in AISdb, you can specify:

    • The desired weather variables (e.g., '10u', '10v', 'tp'),

    • A date range (e.g., August 1 to August 30, 2023), and

    • The directory where the GRIB files are stored.

    To automatically download GRIB files from the Copernicus Climate Data Store (CDS), AISdb provides a convenient option using the WeatherDataStore class.

    Simply set the parameter download_from_cds=True, and specify the required weather variable short names, date range, target area, and output path.

    To extract weather data for specific latitude, longitude, and timestamp values, call the yield_tracks_with_weather() method. This returns a dictionary containing the requested weather variables for each location-time pair.

    Output:

    hashtag
    What are short_names?

    In ECMWF (European Centre for Medium-Range Weather Forecasts) terminology, a "short name" is a concise, often abbreviated identifier used to represent a specific meteorological parameter or variable within data files such as GRIB files. For example, "t2m" is the short name for "2-meter temperature".

    For a list of short names for different weather components, refer to: https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#heading-Parameterlistingsarrow-up-right

    hashtag
    Complete Walkthrough Over an Example

    Let's work through an example where we retrieve AIS tracks from AISdb and call WeatherDataStore to add weather data to the tracks.

    hashtag
    Step 1: Import all necessary packages

    hashtag
    Step 2: Connect to the database

    hashtag
    Step 3: Query the required tracks

    Specify the region and duration for which you wish the tracks to be generated. TrackGen returns a generator yielding tracks with all the dynamic and static column values of the AIS data.

    hashtag
    Step 4: Specify the necessary weather components using short_name convention

    Next, to merge the tracks with the weather data, we need to use the WeatherDataStore class from aisdb.weather.era5. Calling WeatherDataStore with the required weather components (using their short names) and the path to the GRIB files containing the weather dataset opens the files and returns an object through which you can perform further operations.

    Here, 10v and 10u are the 10 metre V wind componentarrow-up-right and the 10 metre U wind componentarrow-up-right.
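
    The snippet below is a minimal sketch of this step. The module path, class name, and short names come from this section; the exact keyword-argument names (short_names, start, end, weather_data_path) are assumptions and may differ in your AISdb version.

    from datetime import datetime
    from aisdb.weather.era5 import WeatherDataStore

    # Minimal sketch -- keyword names are assumptions; check help(WeatherDataStore) for the exact signature
    weather_data_store = WeatherDataStore(
        short_names=['10u', '10v'],          # 10 m U/V wind components
        start=datetime(2023, 8, 1),
        end=datetime(2023, 8, 30),
        weather_data_path='/home/CanadaV2',  # folder holding 2023-08.grib or 2023-08.zip
    )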

    hashtag
    Step 5: Fetch Weather Values for a given latitude, longitude, and time

    By using the method weather_data_store.yield_tracks_with_weather(tracks), the tracks are concatenated with the weather data.

    Example usage:
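
    A minimal sketch, assuming weather_data_store and the tracks generator from Steps 3 and 4 are already in scope:

    # Enrich each queried track with the requested weather variables
    tracks_with_weather = weather_data_store.yield_tracks_with_weather(tracks)

    for track in tracks_with_weather:
        print(track['mmsi'], list(track.keys()))  # weather values appear alongside lon/lat/time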

    hashtag
    Why do we need to integrate weather data with AIS data?

    By integrating weather data with AIS data , we can study the patterns of ship movement in relation to weather conditions. This integration allows us to analyze how factors such as wind, sea surface temperature, and atmospheric pressure influence vessel trajectories. By merging dynamic AIS data with detailed climate variables, we can gain deeper insights into the environmental factors affecting shipping routes and optimize vessel navigation for better efficiency and safety.

    U-V 100m component wind over the Gulf of St. Lawrence for Aug 2018

    AutoEncoders in Keras

    Trajectory Forecasting with Gate Recurrent Units AutoEncoders

    hashtag
    Introduction

    By the end of this tutorial, you will understand the benefits of using teacher forcing to improve model accuracy, as well as other tweaks to enhance forecasting capabilities. We'll use AutoEncoders, neural networks that learn compressed data representations, to achieve this.

    We will guide you through preparing AIS data for training an AutoEncoder, setting up layers, compiling the model, and defining the training process with teacher forcing.

    Given the complexity of this task, we will revisit it to explore the benefits of teacher forcing, a technique that can improve sequence-to-sequence learning in neural networks.

    This tutorial focuses on Trajectory Forecasting, which predicts an object's future path based on past positions. We will work with AIS messages, a type of temporal data that provides information about vessels' location, speed, and heading over time.

    hashtag
    AISdb Querying

    Automatic Identification System (AIS) messages broadcast essential ship information such as position, speed, and course. The temporal nature of these messages is pivotal for our tutorial, where we'll train an auto-encoder neural network for trajectory forecasting. This task involves predicting a ship's future path based on its past AIS messages, making it ideal for auto-encoders, which are optimized for learning patterns in sequential data.

    For querying the entire database at once, use the following code:
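
    A minimal sketch of such a query, following the same DBQuery/TrackGen pattern used elsewhere in this documentation (the database path and time range below are placeholders):

    import aisdb
    from aisdb import DBQuery, sqlfcn_callbacks
    from datetime import datetime

    # Placeholder path and time range -- adjust to your own database
    with aisdb.SQLiteDBConn('ais_atlantic.db') as dbconn:
        qry = DBQuery(dbconn=dbconn,
                      start=datetime(2021, 1, 1), end=datetime(2021, 2, 1),
                      callback=sqlfcn_callbacks.in_timerange)
        tracks = list(aisdb.TrackGen(qry.gen_qry(), decimate=False))  # materialize while the connection is open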

    For querying the database in batches of hours, use the following code:
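
    A sketch of the batched variant, issuing one query per hour so that memory use stays bounded (again with placeholder path, dates, and callback):

    import aisdb
    from aisdb import DBQuery, sqlfcn_callbacks
    from datetime import datetime, timedelta

    def query_in_hourly_batches(dbpath, start, end):
        """Yield tracks one hour at a time to keep memory usage low (sketch)."""
        batch_start = start
        while batch_start < end:
            batch_end = min(batch_start + timedelta(hours=1), end)
            with aisdb.SQLiteDBConn(dbpath) as dbconn:
                qry = DBQuery(dbconn=dbconn, start=batch_start, end=batch_end,
                              callback=sqlfcn_callbacks.in_timerange)
                yield from aisdb.TrackGen(qry.gen_qry(), decimate=False)
            batch_start = batch_end

    # Example: iterate over one day of data in 24 hourly batches
    for track in query_in_hourly_batches('ais_atlantic.db', datetime(2021, 1, 1), datetime(2021, 1, 2)):
        pass  # process each track here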

    Several functions were defined using AISdb, an AIS framework developed by MERIDIAN at Dalhousie University, to efficiently extract AIS messages from SQLite databases. AISdb is designed for effective data storage, retrieval, and preparation for AIS-related tasks. It provides comprehensive tools for interacting with AIS data, including APIs for data reading and writing, parsing AIS messages, and performing various data transformations.

    hashtag
    Data Visualization

    Our next step is to create a coverage map of Atlantic Canada to visualize our dataset. We will include a 100km radius circle on the map to show the areas of the ocean where vessels can send AIS messages. Although overlapping circles may contain duplicate data from the same MMSI, we have already eliminated those from our dataset. However, messages might still appear incorrectly in inland areas.

    AIS Data Cover in the Gulf of St. Lawrence

    Loading a shapefile to help us define whether a vessel is on land or in water during the trajectory:

    Check if a given coordinate (latitude, longitude) is on land:

    Check if any coordinate of a track is on land:
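
    A combined sketch of these helpers, assuming a land-polygon shapefile (the path below is hypothetical) together with geopandas and shapely:

    import geopandas as gpd
    from shapely.geometry import Point

    # Hypothetical land-polygon shapefile (e.g., from Natural Earth); adjust the path
    land = gpd.read_file('ne_10m_land.shp')
    land_sindex = land.sindex  # spatial index for fast point-in-polygon lookups

    def is_on_land(lat, lon):
        """Return True if the (lat, lon) coordinate falls inside a land polygon."""
        point = Point(lon, lat)
        candidate_idx = land_sindex.query(point, predicate='intersects')
        return len(candidate_idx) > 0

    def track_touches_land(track):
        """Return True if any coordinate of an AISdb track dict is on land."""
        return any(is_on_land(lat, lon) for lat, lon in zip(track['lat'], track['lon']))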

    Filter out tracks with any point on land for a given MMSI:

    Use a ThreadPoolExecutor to parallelize the processing of MMSIs:

    Count the number of segments per MMSI after removing duplicates and inaccurate track segments:

    hashtag
    Dataset Preparation

    In this analysis, we observe that most MMSIs in the dataset exhibit between 1 and 49 segments during the search period within AISdb. However, a minor fraction of vessels have significantly more segments, with some reaching up to 176. Efficient processing involves categorizing the data by MMSI instead of merely considering its volume. This method allows us to better evaluate the model's ability to discern various movement behaviors from both the same vessel and different ones.

    To prevent our model from favoring shorter trajectories, we need a balanced mix of short-term and long-term voyages in the training and test sets. We'll categorize trajectories with 30 or more segments as long-term and those with fewer segments as short-term. Implement an 80-20 split strategy to ensure an equitable distribution of both types in the datasets.

    Splitting the data respecting the voyage length distribution:

    Visualizing the distribution of the dataset:

    hashtag
    Inputs & Outputs

    Understanding input and output timesteps and variables is crucial in trajectory forecasting tasks. Trajectory data comprises spatial coordinates and related features that depict an object's movement over time. The aim is to predict future positions of the object based on its historical data and associated features.

    • INPUT_TIMESTEPS: This parameter determines the consecutive observations used to predict future trajectories. Its selection impacts the model's ability to capture temporal dependencies and patterns. Too few time steps may prevent the model from capturing all movement dynamics, resulting in inaccurate predictions. Conversely, too many time steps can add noise and complexity, increasing the risk of overfitting.

    • INPUT_VARIABLES: Features describe each timestep in the input sequence for trajectory forecasting. These variables can include spatial coordinates, velocities, accelerations, object types, and relevant features that aid in predicting system dynamics. Choosing the right input variables is crucial; irrelevant or redundant ones may confuse the model, while missing important variables can result in poor predictions.

    • OUTPUT_TIMESTEPS: This parameter sets the number of future time steps the model should predict, known as the prediction horizon. Choosing the right horizon size is critical. Predicting too few timesteps may not serve the application's needs, while predicting too many can increase uncertainty and degrade performance. Select a value based on your application's specific requirements and data quality.

    • OUTPUT_VARIABLES: In trajectory forecasting, output variables include predicted spatial coordinates and sometimes other relevant features. Reducing the number of output variables can simplify prediction tasks and decrease model complexity. However, this approach might also lead to a less effective model.

    Understanding the roles of input and output timesteps and variables is key to developing accurate trajectory forecasting models. By carefully selecting these elements, we can create models that effectively capture object movement dynamics, resulting in more accurate and meaningful predictions across various applications.

    For this tutorial, we'll input 4 hours of data into the model to forecast the next 8 hours of vessel movement. Consequently, we'll filter out all voyages with less than 12 hours of AIS messages. By interpolating the messages every 5 minutes, we require a minimum of 144 sequential messages (12 hours at 12 messages/hour).

    With data provided by AISdb, we have AIS information, including Longitude, Latitude, Course Over Ground (COG), and Speed Over Ground (SOG), representing a ship's position and movement. Longitude and Latitude specify the ship's location, while COG and SOG indicate its heading and speed. By using all features for training the neural network, our output will be the Longitude and Latitude pair. This methodology allows the model to predict the ship's future positions based on historical data.

    In this tutorial, we'll include AIS data deltas as features, which were excluded in the previous tutorial. Incorporating deltas can help the model capture temporal changes and patterns, enhancing its effectiveness in sequence-to-sequence modeling. Deltas provides information on the rate of change in features, improving the model's accuracy, especially in predicting outcomes that depend on temporal dynamics.

    hashtag
    Data Filtering

    hashtag
    Data Statistics

    hashtag
    Sample Weighting

    hashtag
    Distance Function

    To improve our model, we'll prioritize training samples based on trajectory straightness. We'll compute the geographical distance between a segment's start and end points using the Haversine formula. Comparing this to the total distance of all consecutive points will give a straightness metric. Our model will focus on complex trajectories with multiple direction changes, leading to better generalization and more accurate predictions.

    hashtag
    Complexity Score

    Trajectory straightness calculation using the Haversine:
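
    A minimal sketch of the straightness metric described above; a sample weight can then be derived from it (for example 1 - straightness) so that complex trajectories receive more attention during training:

    import numpy as np

    def haversine_km(lat1, lon1, lat2, lon2):
        """Great-circle distance in kilometres between two points."""
        lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
        a = np.sin((lat2 - lat1) / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
        return 2 * 6371.0 * np.arcsin(np.sqrt(a))

    def straightness(lats, lons):
        """Ratio of start-to-end distance over the summed point-to-point distance (1.0 = perfectly straight)."""
        path = sum(haversine_km(lats[i], lons[i], lats[i + 1], lons[i + 1]) for i in range(len(lats) - 1))
        direct = haversine_km(lats[0], lons[0], lats[-1], lons[-1])
        return direct / path if path > 0 else 1.0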

    hashtag
    Sample Windowing

    To predict 96 data points (output) using the preceding 48 data points (input) in a trajectory time series, we create a sliding window. First, we select the initial 48 data points as the input sequence and the subsequent 96 as the output sequence. We then slide the window forward by one step and repeat the process. This continues until the end of the sequence, helping our model capture temporal dependencies and patterns in the data.

    Our training strategy uses the sliding window technique, requiring unique weights for each sample. Sliding Windows (SW) transforms time series data into an appropriate format for machine learning. They generate overlapping windows with a fixed number of consecutive points by sliding the window one step at a time through the series.
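
    A minimal NumPy sketch of this windowing, assuming each interpolated voyage is an array of shape (timesteps, features) with longitude and latitude in the first two columns (that column order is an assumption):

    import numpy as np

    def sliding_windows(sequence, input_steps=48, output_steps=96):
        """Split one interpolated voyage (timesteps, features) into overlapping (X, Y) pairs."""
        X, Y = [], []
        for start in range(len(sequence) - input_steps - output_steps + 1):
            X.append(sequence[start:start + input_steps])  # 4 h of inputs (all features)
            Y.append(sequence[start + input_steps:start + input_steps + output_steps, :2])  # next 8 h of lon/lat
        return np.array(X), np.array(Y)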

    In this project, the input data includes four features: Longitude, Latitude, COG (Course over Ground), and SOG (Speed over Ground), while the output data includes only Longitude and Latitude. To enhance the model's learning, we need to normalize the data through three main steps.

    First, normalize Longitude, Latitude, COG, and SOG to the [0, 1] range using domain-specific parameters. This ensures the model performs well in Atlantic Canada waters by restricting the geographical scope of the AIS data and maintaining a similar scale for all features.

    Second, the input and output data are standardized by subtracting the mean and dividing by the standard deviation. This centers the data around zero and scales it by its variance, preventing vanishing gradients during training.

    Finally, another zero-one normalization is applied to scale the data to the [0, 1] range, aligning it with the expected range for many neural network activation functions.

    Denormalizing Y output to the original scale of the data:

    Denormalizing X output to the original scale of the data:

    machine-learningWe have successfully prepared the data for our machine-learning task. With the data ready, it's time for the modeling phase. Next, we will create, train, and evaluate a machine-learning model to forecast vessel trajectories using the processed dataset. Let's explore how our model performs in Atlantic Canada!

    hashtag
    Gated Recurrent Unit AutoEncoder

    A GRU Autoencoder is a neural network that compresses and reconstructs sequential data utilizing a Gated Recurrent Unit. GRUs are highly effective at handling time-series data, which are sequential data points captured over time, as they can model intricate temporal dependencies and patterns. To perform time-series forecasting, a GRU Autoencoder can be trained on a historical time-series dataset to discern patterns and trends, subsequently compressing a sequence of future data points into a lower-dimensional representation that can be decoded to generate a forecast of the upcoming data points. With this in mind, we will begin by constructing a model architecture composed of two GRU layers with 64 units each, taking input of shape (48, 4) and (96, 4), respectively, followed by a dense layer with 2 units.
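
    A minimal Keras sketch matching this description (the decoder input is the 96-step teacher-forcing sequence; the custom loss and training loop described below are not included, and the optimizer/loss here are placeholders):

    from tensorflow.keras import layers, Model

    # Encoder: 48 timesteps x 4 features -> GRU(64) summary state
    encoder_inputs = layers.Input(shape=(48, 4), name='encoder_input')
    _, encoder_state = layers.GRU(64, return_state=True, name='encoder_gru')(encoder_inputs)

    # Decoder: 96-step teacher-forcing sequence, initialized with the encoder state
    decoder_inputs = layers.Input(shape=(96, 4), name='decoder_input')
    decoder_seq = layers.GRU(64, return_sequences=True, name='decoder_gru')(
        decoder_inputs, initial_state=encoder_state)

    # Dense head with 2 units predicting longitude/latitude at every output step
    outputs = layers.TimeDistributed(layers.Dense(2), name='lonlat_output')(decoder_seq)

    model = Model([encoder_inputs, decoder_inputs], outputs)
    model.compile(optimizer='adam', loss='mae')
    model.summary()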

    hashtag
    Custom Loss Function

    hashtag
    Model Summary

    hashtag
    Training Callbacks

    The following function lists callbacks used during the model training process. Callbacks are utilities invoked at specific points during training to monitor progress or take actions based on the model's performance. The function pre-defines the parameters and behavior of these callbacks:

    • WandbMetricsLogger: This callback logs the training and validation metrics for visualization and monitoring on the Weights & Biases (W&B) platform. This can be useful for tracking the training progress but may introduce additional overhead due to the logging process. You can remove this callback if you don't need to use W&B or want to reduce the overhead.

    • TerminateOnNaN: This callback terminates training if the loss becomes NaN (Not a Number) during the training process. It helps to stop the training process early when the model diverges and encounters an unstable state.

    • ReduceLROnPlateau: This callback reduces the learning rate by a specified factor when the monitored metric has stopped improving for several epochs. It helps fine-tune the model using a lower learning rate when it no longer improves significantly.

    • EarlyStopping: This callback stops the training process early when the monitored metric has not improved for a specified number of epochs. It restores the model's best weights when the training is terminated, preventing overfitting and reducing the training time.

    • ModelCheckpoint: This callback saves the best model (based on the monitored metric) to a file during training.

    WandbMetricsLogger is the most computationally costly among these callbacks due to the logging process. You can remove this callback if you don't need Weights & Biases for monitoring or want to reduce overhead. The other callbacks help optimize the training process and are less computationally demanding. Note that the Weights & Biases (W&B) platform is also used in other parts of the code; if you decide to remove the WandbMetricsLogger callback, please ensure that you also remove any other references to W&B to avoid potential issues. If you choose to use W&B for monitoring and logging, you must register and log in to the W&B websitearrow-up-right. During the execution of the code, you'll be prompted for an authentication key to connect your script to your W&B account. This key can be obtained from your W&B account settings; once you have it, W&B's monitoring and logging features will be enabled.

    hashtag
    Model Training

    hashtag
    Model Evaluation

    hashtag
    Hyperparameters Tuning

    In this step, we define a function called model_placeholder that uses the Keras Tuner to create a model with tunable hyperparameters. The function takes a hyperparameter object as input, which defines the search space for the hyperparameters of interest. Specifically, we are searching for the best number of units in the encoder and decoder GRU layers and the optimal learning rate for the AdamW optimizer. The model_placeholder function constructs a GRU-AutoEncoder model with these tunable hyperparameters and compiles the model using the Mean Absolute Error (MAE) as the loss function. Keras Tuner will use this model during the hyperparameter optimization process to find the best combination of hyperparameters that minimizes the validation loss, at the expense of longer computing time.
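
    A hedged sketch of such a placeholder using Keras Tuner; the search ranges, tuner choice, and layer wiring below are illustrative assumptions, not the tutorial's exact settings:

    import keras_tuner as kt
    import tensorflow as tf
    from tensorflow.keras import layers, Model

    def model_placeholder(hp):
        # Tunable hyperparameters (ranges are assumptions)
        units = hp.Int('units', min_value=32, max_value=128, step=32)
        learning_rate = hp.Float('learning_rate', 1e-4, 1e-2, sampling='log')

        encoder_inputs = layers.Input(shape=(48, 4))
        _, state = layers.GRU(units, return_state=True)(encoder_inputs)
        decoder_inputs = layers.Input(shape=(96, 4))
        decoded = layers.GRU(units, return_sequences=True)(decoder_inputs, initial_state=state)
        outputs = layers.TimeDistributed(layers.Dense(2))(decoded)

        model = Model([encoder_inputs, decoder_inputs], outputs)
        # AdamW is available in TF >= 2.11; swap in another optimizer for older versions
        model.compile(optimizer=tf.keras.optimizers.AdamW(learning_rate=learning_rate), loss='mae')
        return model

    tuner = kt.BayesianOptimization(model_placeholder, objective='val_loss', max_trials=10)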

    Helper for saving the training history:

    Helper for restoring the training history:

    Defining the model to be optimized:

    HyperOpt Objective Function:

    hashtag
    Search for Best Model

    Scanning the project folder for other pre-trained weights shared with this tutorial:

    hashtag
    Evaluating Best Model

    hashtag
    Model Explainability

    hashtag
    Permutation Feature Importance (PFI)

    Deep learning models, although powerful, are often criticized for their lack of explainability, making it difficult to comprehend their decision-making process and raising concerns about trust and reliability. To address this issue, we can use techniques like the PFI method, a simple, model-agnostic approach that helps visualize the importance of features in deep learning models. This method works by shuffling individual feature values in the dataset and observing the impact on the model's performance. By measuring the change in a designated metric when each feature's values are randomly permuted, we can infer the importance of that specific feature. The idea is that if a feature is crucial for the model's performance, shuffling its values should lead to a significant shift in performance; otherwise if a feature has little impact, its value permutation should result in a minor change. Applying the permutation feature importance method to the best model, obtained after hyperparameter tuning, can give us a more transparent understanding of how the model makes its decisions.
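
    A minimal, model-agnostic sketch of PFI for sequence inputs shaped (samples, timesteps, features); predict_fn is any function mapping inputs to predictions, so it can wrap the tuned model however it is invoked:

    import numpy as np

    def permutation_feature_importance(predict_fn, X, y, n_repeats=3, seed=42):
        """Shuffle one input feature at a time and measure the increase in MAE."""
        rng = np.random.default_rng(seed)
        baseline = np.mean(np.abs(predict_fn(X) - y))
        importances = []
        for feature in range(X.shape[-1]):
            scores = []
            for _ in range(n_repeats):
                X_perm = X.copy()
                shuffled = rng.permutation(X_perm.shape[0])
                X_perm[:, :, feature] = X_perm[shuffled, :, feature]  # permute this feature across samples
                scores.append(np.mean(np.abs(predict_fn(X_perm) - y)))
            importances.append(np.mean(scores) - baseline)  # larger increase in error = more important feature
        return importances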

    hashtag
    Sensitivity Analysis

    Permutation feature importance has some limitations, such as assuming features are independent and producing biased results when features are highly correlated. It also doesn't provide detailed explanations for individual data points. An alternative is sensitivity analysis, which studies how input features affect model predictions. By perturbing each input feature individually and observing the prediction changes, we can understand which features significantly impact the model's output. This approach offers insights into the model's decision-making process and helps identify influential features. However, it does not account for feature interactions and can be computationally expensive for many features or perturbation steps.

    hashtag
    UMAP: Uniform Manifold Approximation and Projection

    UMAP is a nonlinear dimensionality reduction technique that visualizes high-dimensional data in a lower-dimensional space, preserving the local and global structure. In trajectory forecasting, UMAP can project high-dimensional model representations into 2D or 3D to clarify the relationships between input features and outputs. Unlike sensitivity analysis, which measures prediction changes due to input feature perturbations, UMAP reveals data structure without perturbations. It also differs from feature permutation, which evaluates feature importance by shuffling values and assessing model performance changes. UMAP focuses on visualizing intrinsic data structures and relationships.

    hashtag
    Final Considerations

    GRUs can effectively forecast vessel trajectories but have notable downsides. A primary limitation is their struggle with long-term dependencies due to the vanishing gradient problem, causing the loss of relevant information from earlier time steps. This makes capturing long-term patterns in vessel trajectories challenging. Additionally, GRUs are computationally expensive with large datasets and long sequences, resulting in longer training times and higher memory use. While outperforming basic RNNs, they may not always surpass advanced architectures like LSTMs or Transformer models. Furthermore, the interpretability of GRU-based models is a challenge, which can hinder their adoption in safety-critical applications like vessel trajectory forecasting.

    seq2seq in PyTorch

    Sequence to Sequence using Torch

    Vessel trajectories are a type of geospatial temporal data derived from AIS (Automatic Identification System) signals. In this tutorial, we will go over the most common Machine Learning Library to process and model AIS trajectory data.

    We will begin with PyTorch, a widely used deep learning library designed for building and training neural networks. Specifically, we will implement a recurrent neural network using LSTM (Long Short-Term Memory) to model sequential patterns in vessel movements.

    We will utilize AISdb, a dedicated framework for querying, filtering, and preprocessing vessel trajectory data, to streamline data preparation for machine learning workflows.

    hashtag
    Setting Up Our Tools

    First, let's import the libraries we'll be using throughout this tutorial. Our main tools will be NumPy and PyTorch, along with a few other standard libraries for data handling, model building, and visualization.

    • pandas, numpy: for handling tables and arrays

    • torch: for building and training deep learning models

    • sklearn: for data splitting and evaluation utilities

    • matplotlib: for visualizing model performance and outputs

    Assuming you have the database ready, you can replace the file path and establish a connection.

    We have processed a sample SQLite databasearrow-up-right containing open-source AIS data from Marine Cadastre, covering January to March near Maine, United States.

    To generate the query using AISdb, we use the DBQueryarrow-up-right function. All you have to change here are the DB_CONNECTION, START_DATE, END_DATE, and the bounding coordinates.

    Sample coordinates look like this on the map:

    We use pyproj for the metric projection of the latitude and longitude values. Learn more herearrow-up-right.

    hashtag
    Preprocessing

    We follow the steps listed below to preprocess the queried trajectory data:

    • Group data based on MMSI

    • Filter out MMSIs with fewer than 100 points

    • Remove pings with respect to speed

    • Encode tracks given a threshold

    • Interpolate according to time (5 minutes here)

    • Convert latitude/longitude to x and y on a Cartesian plane using pyproj

    • Use the sine and cosine of COG, since it is a 360-degree value

    • Drop NaN values

    • Apply scaling to ensure values are normalized

    The steps above are wrapped into the function defined as:

    Next, we process all vessel tracks and split them into training and test sets, which are used for model training and evaluation.

    hashtag
    Create Sequences

    For geospatial-temporal data, we typically use a sliding window approach, where each trajectory is segmented into input sequences of length X to predict the next Y steps. In this tutorial, we set X = 80 and Y = 2.

    We then save all this data as well as the scalers (we'll use this towards the end in evaluation)

    hashtag
    Load Data

    Now we can load the data and start experimenting with it. The same data can also be reused across different models we want to explore.

    hashtag
    Machine Learning Model - Long Short Term Memory (LSTM)

    We use an attention-based encoder–decoder LSTM model for trajectory prediction. The model has two layers and incorporates teacher forcing, a strategy where the decoder is occasionally fed the ground-truth values during training. This helps stabilize learning and prevents the model from drifting too far when making multi-step predictions.
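
    A simplified PyTorch sketch of this idea, with the attention mechanism omitted for brevity; the hidden size and the choice to seed the decoder with the last observed position are assumptions, while the two layers, 80-step input, and 2-step horizon follow the values used in this tutorial.

    import torch
    import torch.nn as nn

    class Seq2SeqLSTM(nn.Module):
        """Simplified encoder-decoder LSTM with teacher forcing (attention omitted)."""

        def __init__(self, n_features=4, hidden_size=64, n_layers=2, horizon=2):
            super().__init__()
            self.horizon = horizon
            self.encoder = nn.LSTM(n_features, hidden_size, n_layers, batch_first=True)
            self.decoder = nn.LSTM(2, hidden_size, n_layers, batch_first=True)  # decoder consumes x/y positions
            self.head = nn.Linear(hidden_size, 2)                               # predicts the next x/y position

        def forward(self, src, target=None, teacher_forcing_ratio=0.5):
            # src: (batch, 80, n_features); target: (batch, horizon, 2) during training
            _, (h, c) = self.encoder(src)
            decoder_input = src[:, -1:, :2]  # assumes x/y are the first two features
            outputs = []
            for t in range(self.horizon):
                out, (h, c) = self.decoder(decoder_input, (h, c))
                pred = self.head(out)        # (batch, 1, 2)
                outputs.append(pred)
                use_teacher = target is not None and torch.rand(1).item() < teacher_forcing_ratio
                decoder_input = target[:, t:t + 1, :] if use_teacher else pred
            return torch.cat(outputs, dim=1)  # (batch, horizon, 2)

    # Example usage with random tensors shaped like the prepared sequences
    model = Seq2SeqLSTM()
    x, y = torch.randn(8, 80, 4), torch.randn(8, 2, 2)
    loss = nn.functional.mse_loss(model(x, target=y), y)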

    hashtag
    Auxiliary Loss Components

    Two auxiliary functions are introduced to augment the original MSE loss. These additional terms are designed to better preserve the physical consistency and structural shape of the predicted trajectory.

    hashtag
    Model Training

    Once the model is defined, the next step is to train it on our prepared dataset. Training involves iteratively feeding input sequences to the model, comparing its predictions against the ground truth, and updating the weights to reduce the error.

    In our case, the loss function combines:

    • a data term (based on weighted coordinate errors and auxiliary features), and

    • a smoothness penalty (to encourage realistic vessel movement and reduce jitter in the predicted trajectory).

    hashtag
    Model Evaluation

    Finally, now that our model has been trained, we use an evaluation function to check it on the test dataset we stored earlier, and we also plot the predictions on a map to see how the forecast trajectories look. Note: we don't rely solely on the accuracy or the numeric training/testing results. The loss can look small in decimal terms while the predicted coordinates are far off, which is why we also plot the predictions on a map to check them.

    There are also some debugging statements to verify whether the scaling is correct, the distance error, and so on. In this model, we obtain a mean distance error of only about 800 m.

    hashtag
    Results

    Predicted vs True (lat/lon)

    t    lon_true     lon_pred     lat_true    lat_pred    Error (m)
    0    -61.69744    -61.70585    43.22816    43.22385    833.31

    Summary (meters): t=0 mean error: 833.31 m; mean over horizon: 833.31 m; median: 833.31 m

    AISdb Made Easy: A No-Code Interface

    The AISdb no-code interface offers a complete, visual way to process and analyze AIS data without writing a single line of code. Users can split trajectories by time, encode vessel tracks, discretize locations using H3, and detect vessel stops with adjustable parameters. The interface also supports adding environmental context such as bathymetry and weather layers. With customizable controls for gap duration, time segmentation, encoding distance and speed thresholds, and H3 resolution, the platform streamlines preprocessing while maintaining full analytical flexibility. Whether you’re preparing data for modeling, visualization, or research, every step—from segmentation to environmental enrichment—can be configured directly through an intuitive, click-based workflow. Below is the file you can download and run locally to access the interface.

    hashtag
    Interface

    On the landing page, you can simply drag and drop your CSV with AIS data; the assistant will automatically give you details about the CSV, such as the total number of rows, means, etc.

    Using Newtonian PINNs

    Traditional Seq2Seq LSTM models have long been the workhorse for trajectory forecasting. They excel at learning temporal dependencies in AIS data, and with attention mechanisms they can capture complex nonlinear patterns over long histories. However, they remain purely statistical. This means the model can generate plausible-looking trajectories from a data perspective but with no guarantee that those predictions respect the underlying physics of vessel motion. In practice, this often manifests as sharp turns, unrealistic accelerations, or trajectories that deviate significantly when the model faces sparse or noisy data.

    The NPINN-based approach directly addresses these shortcomings. By embedding smoothness and kinematic penalties into training, it enforces constraints on velocity and acceleration while still benefiting from the representational power of deep sequence models. Instead of simply fitting residuals between past and future positions, NPINN ensures that predictions evolve in ways consistent with how vessels actually move in the physical world. This leads to more reliable extrapolation, especially in data-scarce regions or unusual navigation scenarios.
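
    A hedged sketch of what such penalties can look like in PyTorch; the sampling interval, speed cap, and equal weighting below are assumptions, not the values used by the NPINN implementation:

    import torch

    def kinematic_penalty(pred, dt=300.0, v_max=15.0):
        """Smoothness/kinematic penalties on predicted positions (batch, T, 2) in metres.

        dt is the assumed sampling interval in seconds; v_max is an assumed speed cap in m/s.
        """
        vel = (pred[:, 1:, :] - pred[:, :-1, :]) / dt   # finite-difference velocity
        acc = (vel[:, 1:, :] - vel[:, :-1, :]) / dt     # finite-difference acceleration
        smoothness = (acc ** 2).mean()                  # discourage jerky, non-smooth motion
        speed = vel.norm(dim=-1)                        # speed magnitude per step
        overspeed = torch.relu(speed - v_max).mean()    # penalize physically implausible speeds
        return smoothness + overspeed

    # Example: combine with a data-fitting loss (weighting is an assumption)
    # loss = mse_loss + 0.1 * kinematic_penalty(pred_xy_metres)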


    Building a RAG Chatbot

    One of the most potent applications enabled by LLMs is the development of question-answering chatbots. These are applications that can answer questions about specific source information. This tutorial demonstrates how to build a chatbot that can answer questions about AISViz documentation using a technique known as Retrieval-Augmented Generation (RAG).

    hashtag
    Overview

    A typical RAG application has two main components:

    Embedding with traj2vec

    AIS data is messy. It comes as a stream of latitude, longitude, timestamps, and a few metadata fields. By themselves, they’re hard to compare or feed into downstream ML models. What we really want is a compact, fixed-length representation that captures how a vessel actually moves.

    Inspired by word2vec, traj2vec applies the same logic to movement data: instead of predicting the next word, it predicts the next location in a sequence. Just like words gain meaning from context, vessel positions gain meaning from their trajectory history.

    The result: trajectories that “look alike” end up close together in embedding space. For instance, two ferries running parallel routes will embed similarly, while a cargo vessel crossing the Gulf of Mexico will sit far away from a fishing boat looping off the coast.

    hashtag
    Imports

    import aisdb
    import numpy as np
    
    # Assuming you have a database connection and domain set up as described
    with aisdb.SQLiteDBConn(dbpath='your_ais_database.db') as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn,
            start='2023-01-01', end='2023-01-02',  # Example time range
            xmin=-10, xmax=0, ymin=40, ymax=50,    # Example bounding box
            callback=aisdb.database.sqlfcn_callbacks.in_validmmsi_bbox,
        )
    
        simplified_tracks = aisdb.TrackGen(qry.gen_qry(), decimate=True)  # Generate tracks with decimation applied
    
        for segment in simplified_tracks:
            print(f"Simplified track for MMSI: {segment['mmsi']}, Points: {segment['lon'].size}")
    
    import aisdb
    import numpy as np
    
    # Assuming you have a database connection and domain set up as described
    with aisdb.SQLiteDBConn(dbpath='your_ais_database.db') as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn,
            start='2023-01-01', end='2023-01-02',  # Example time range
            xmin=-10, xmax=0, ymin=40, ymax=50,    # Example bounding box
            callback=aisdb.database.sqlfcn_callbacks.in_validmmsi_bbox,
        )
    
        tracks = aisdb.TrackGen(qry.gen_qry(), decimate=False)  # Generate initial tracks
    
        simplified_tracks = []
    
        for track in tracks:
            if track['lon'].size > 2:  # Ensure track has enough points
                # Apply decimation using simplify_linestring_idx
                simplified_indices = aisdb.track_gen.simplify_linestring_idx(
                    track['lon'], track['lat'], precision=0.01  # Example precision
                )
                # Extract simplified track points
                simplified_track = {
                    'mmsi': track['mmsi'],
                    'time': track['time'][simplified_indices],
                    'lon': track['lon'][simplified_indices],
                    'lat': track['lat'][simplified_indices],
                    # Carry over other relevant track attributes as needed
                }
                simplified_tracks.append(simplified_track)
            else:
                simplified_tracks.append(track)  # Keep tracks with few points as is
    
        # Now 'simplified_tracks' contains decimated tracks ready for further analysis
        for segment in simplified_tracks:
            print(f"Simplified track for MMSI: {segment['mmsi']}, Points: {segment['lon'].size}")
    
    https://www.marinelink.com/news/definitive-guide-ais418266arrow-up-right
    http://hdl.handle.net/10222/81160arrow-up-right
    https://www.marinelink.com/news/definitive-guide-ais418266arrow-up-right
    Image from Marinelinkarrow-up-right
    (Amigo et al., 2021)
    Red: Tracks with an average depth greater than 20,000 meters (deepest waters).
    Vessel tracks colored with average depths relative to the bathymetry
    https://github.com/AISViz/AISdb/blob/master/examples/discretize.ipynbarrow-up-right
    https://www.uber.com/en-CA/blog/h3/arrow-up-right
    And the
    directory
    where the GRIB files are stored.
    https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#heading-Parameterlistingsarrow-up-right
    10 metre U wind componentarrow-up-right
    10 metre V wind componentarrow-up-right
    U-V 100m component wind over Gulf Of St. Lawrance for Aug 2018
  • OUTPUT_TIMESTEPS: This parameter sets the number of future time steps the model should predict, known as the prediction horizon. Choosing the right horizon size is critical: predicting too few timesteps may not serve the application's needs, while predicting too many can increase uncertainty and degrade performance. Select a value based on your application's specific requirements and data quality.
  • OUTPUT_VARIABLES: In trajectory forecasting, output variables include predicted spatial coordinates and sometimes other relevant features. Reducing the number of output variables can simplify the prediction task and decrease model complexity; however, it might also lead to a less effective model.

  • ReduceLROnPlateau: This callback reduces the learning rate by a specified factor when the monitored metric has stopped improving for several epochs. It helps fine-tune the model with a lower learning rate once improvement stalls.
  • EarlyStopping: This callback stops training early when the monitored metric has not improved for a specified number of epochs. It restores the model's best weights when training is terminated, preventing overfitting and reducing training time.

  • ModelCheckpoint: This callback saves the best model (based on the monitored metric) to a file during training. A minimal sketch of these three callbacks follows this list.
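
    The following Keras sketch wires up the three callbacks described above; the monitored metric, patience values, and checkpoint path are placeholder choices, and the create_callbacks() helper defined later in this tutorial is what the training code actually uses.

    from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint

    callbacks = [
        # Multiply the learning rate by 0.2 after 3 epochs without val_loss improvement
        ReduceLROnPlateau(monitor="val_loss", factor=0.2, patience=3),
        # Stop after 12 stagnant epochs and roll back to the best weights seen so far
        EarlyStopping(monitor="val_loss", patience=12, restore_best_weights=True),
        # Keep only the best model on disk (placeholder path)
        ModelCheckpoint("best_model.h5", monitor="val_loss", save_best_only=True),
    ]
    # Pass these to model.fit(..., callbacks=callbacks)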

    Figure: AIS data coverage in the Gulf of St. Lawrence.
  • W&B website: Weights & Biases, used here to log and track training runs
  • scikit-learn: for data splitting and evaluation utilities
  • matplotlib: for visualizing model performance and outputs (a brief example follows this list)
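
    A short, self-contained illustration of the two roles described above; the synthetic arrays and the made-up loss curve are placeholders, not values produced anywhere in this tutorial.

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.model_selection import train_test_split

    # Synthetic stand-in data: 1,000 samples, 4 input features, 2 targets
    x_all = np.random.rand(1000, 4)
    y_all = np.random.rand(1000, 2)

    # 80/20 split for training and evaluation
    x_tr, x_te, y_tr, y_te = train_test_split(x_all, y_all, test_size=0.2, random_state=42)
    print(x_tr.shape, x_te.shape)  # (800, 4) (200, 4)

    # Plot a made-up loss curve the way a Keras training history is usually visualized
    losses = np.exp(-np.linspace(0, 3, 30)) + 0.05 * np.random.rand(30)
    plt.plot(losses, label="training loss")
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.show()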

  • Group the data by MMSI
  • Filter out MMSIs with fewer than 100 points
  • Convert latitude/longitude to x and y on a Cartesian plane using pyproj
  • Use the sine and cosine of COG, since it is a 360-degree value
  • Drop NaN values
  • Apply scaling to ensure the values are normalized (a minimal sketch of these steps follows this list)
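
    As a rough illustration of the preprocessing steps above, the sketch below assumes a pandas DataFrame with 'mmsi', 'lat', 'lon', and 'cog' columns; the column names, EPSG codes, and scaler choice are assumptions rather than the exact pipeline used in this tutorial.

    import numpy as np
    import pandas as pd
    from pyproj import Transformer
    from sklearn.preprocessing import MinMaxScaler

    def preprocess(df: pd.DataFrame) -> pd.DataFrame:
        # Group by MMSI and keep only vessels with at least 100 AIS points
        df = df.groupby("mmsi").filter(lambda g: len(g) >= 100).copy()

        # Project lat/lon (EPSG:4326) to planar x/y in metres (EPSG:3857 used as an example)
        transformer = Transformer.from_crs("EPSG:4326", "EPSG:3857", always_xy=True)
        df["x"], df["y"] = transformer.transform(df["lon"].values, df["lat"].values)

        # Encode course-over-ground with sine and cosine so 359° and 1° remain close together
        cog_rad = np.deg2rad(df["cog"])
        df["cog_sin"], df["cog_cos"] = np.sin(cog_rad), np.cos(cog_rad)

        # Drop missing values, then scale the selected features to [0, 1]
        df = df.dropna()
        features = ["x", "y", "cog_sin", "cog_cos"]
        df[features] = MinMaxScaler().fit_transform(df[features])
        return df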

    Example output (row 0): lon -61.69744 and -61.70585; lat 43.22816 and 43.22385; distance 833.31 m.

    import os
    import aisdb
    import nest_asyncio
    
    from datetime import datetime
    from aisdb.database.dbconn import SQLiteDBConn
    from aisdb import DBConn, DBQuery, DomainFromPoints
    
    nest_asyncio.apply()
    
    # set the path to the data storage directory
    bathymetry_data_dir = "./bathymetry_data/"
    
    # check if the directory exists
    if not os.path.exists(bathymetry_data_dir):
        os.makedirs(bathymetry_data_dir)
    
    # check if the directory is empty
    if os.listdir(bathymetry_data_dir) == []:
        # download the bathymetry data
        bathy = aisdb.webdata.bathymetry.Gebco(data_dir=bathymetry_data_dir)
        bathy.fetch_bathymetry_grid()
    else:
        print("Bathymetry data already exists.")
    def add_color(tracks):
        for track in tracks:
            # Calculate the average coastal distance
            avg_coast_distance = sum(abs(dist) for dist in track['coast_distance']) / len(track['coast_distance'])
            
            # Determine the color based on the average coastal distance
            if avg_coast_distance <= 100:
                track['color'] = "yellow"
            elif avg_coast_distance <= 1000:
                track['color'] = "orange"
            elif avg_coast_distance <= 20000:
                track['color'] = "pink"
            else:
                track['color'] = "red"
            
            yield track
    dbpath = 'YOUR_DATABASE.db' # define the path to your database
    end_time = datetime.strptime("2018-01-02 00:00:00", '%Y-%m-%d %H:%M:%S')
    start_time = datetime.strptime("2018-01-01 00:00:00", '%Y-%m-%d %H:%M:%S')
    domain = DomainFromPoints(points=[(-63.6, 44.6)], radial_distances=[500000])
    
    with SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = DBQuery(
            dbconn=dbconn, start=start_time, end=end_time,
            xmin=domain.boundary['xmin'], xmax=domain.boundary['xmax'],
            ymin=domain.boundary['ymin'], ymax=domain.boundary['ymax'],
            callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi,
        )
        tracks = aisdb.track_gen.TrackGen(qry.gen_qry(), decimate=False)
    
        # Merge the tracks with the raster data
        raster_path = "./bathymetry_data/gebco_2022_n90.0_s0.0_w-90.0_e0.0.tif"
        raster = aisdb.webdata.load_raster.RasterFile(raster_path)
        tracks_raster = raster.merge_tracks(tracks, new_track_key="coast_distance")
    
        # Add color to the tracks
        tracks_colored = add_color(tracks_raster)
    
        if __name__ == '__main__':
            aisdb.web_interface.visualize(
                tracks_colored,
                visualearth=True,
                open_browser=True,
            )
    import aisdb
    from aisdb import DBQuery
    from aisdb.database.dbconn import PostgresDBConn
    from datetime import datetime, timedelta
    from aisdb.discretize.h3 import Discretizer  # main import to convert lat/lon to H3 indexes
    
    # >>> PostgreSQL connection details (replace placeholders or use environment variables) <<<
    db_user = '<>'            # PostgreSQL username
    db_dbname = '<>'          # PostgreSQL database/schema name
    db_password = '<>'        # PostgreSQL password
    db_hostaddr = '127.0.0.1' # PostgreSQL host (localhost shown)
    
    # Create a database connection handle for AISDB to use
    dbconn = PostgresDBConn(
        port=5555,            # PostgreSQL port (5432 is default; 5555 here is just an example)
        user=db_user,         # username for authentication
        dbname=db_dbname,     # database/schema to connect to
        host=db_hostaddr,     # host address or DNS name
        password=db_password, # password for authentication
    )
    
    # ------------------------------
    # Define the spatial and temporal query window
    # Coordinates are in lon/lat degrees; the named variables below help readability
    xmin, ymin, xmax, ymax = -70, 45, -58, 53
    gulf_bbox = [xmin, xmax, ymin, ymax]  # optional helper; not used directly below
    
    start_time = datetime(2023, 8, 1)     # query start (inclusive)
    end_time   = datetime(2023, 8, 2)     # query end (exclusive or inclusive per DB settings)
    
    # Build a query that streams AIS rows in the time window and bounding box
    qry = DBQuery(
        dbconn=dbconn,
        start=start_time, end=end_time,
        xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax,
        # Callback filters rows by time, bbox, and ensures MMSI validity (helps remove junk)
        callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi
    )
    
    # Prepare containers/generators for streamed processing (memory-efficient)
    ais_tracks = []          # placeholder list if you want to collect tracks (unused below)
    rowgen = qry.gen_qry()   # generator that yields raw AIS rows from the database
    
    # ------------------------------
    # Instantiate the H3 Discretizer at a chosen resolution
    # Resolution 6 ≈ regional scale hexagons; increase for finer grids, decrease for coarser grids
    discretizer = Discretizer(resolution=6)
    
    # Build tracks from rows; decimate=True reduces oversampling and speeds up processing
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=True)
    
    # Optionally split long tracks into time-bounded segments (e.g., 4-week chunks)
    # Useful for chunked processing or time-based aggregation; not used further in this snippet
    tracks_segment = aisdb.track_gen.split_timedelta(
        tracks,
        timedelta(weeks=4)
    )
    
    # Discretize each track: adds an H3 index array aligned with lat/lon points for that track
    # Each yielded track will have keys like 'lat', 'lon', and 'h3_index'
    tracks_with_indexes = discretizer.yield_tracks_discretized_by_indexes(tracks)
    
    # Example (optional) usage:
    # for t in tracks_with_indexes:
    #     # Access the first point's H3 index for this track
    #     print(t['mmsi'], t['time'][0], t['lat'][0], t['lon'][0], t['h3_index'][0])
    #     break
    
    # Output: H3 Index for lat 50.003334045410156, lon -66.76000213623047: 860e4d31fffffff
    from aisdb.weather.data_store import WeatherDataStore # for weather
    
    # ...some code before...
    with aisdb.SQLiteDBConn(dbpath=dbpath) as dbconn:
        qry = aisdb.DBQuery(
            dbconn=dbconn,
            start=start_time,
            end=end_time,
            callback=aisdb.database.sqlfcn_callbacks.in_timerange_validmmsi,
        )
        rowgen = qry.gen_qry()
        
        # Convert queried rows to vessel trajectories
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
        
        # Mention the short_names for the required weather data from the GRIB file
        weather_data_store = WeatherDataStore(
            short_names=['10u', '10v', 'tp'],
            start=start_time,
            end=end_time,
            weather_data_path=".",
            download_from_cds=False,
            area=[-70, 45, -58, 53],
        )
    
        tracks = weather_data_store.yield_tracks_with_weather(tracks)
        
        for t in tracks:
            print(f"'u-component' 10m wind for:\nlat: {t['lat'][0]} \nlon: {t['lon'][0]} \ntime: {t['time'][0]} \nis {t['weather_data']['u10'][0]} m/s")
            break
        
        weather_data_store.close()
    'u-component' 10m wind for: 
    lat: 50.003334045410156 
    lon: -66.76000213623047 
    time: 1690858823 
    is 1.9680767059326172 m/s
    import aisdb
    import nest_asyncio
    from aisdb import DBQuery
    from aisdb.database.dbconn import PostgresDBConn
    from datetime import datetime
    from PIL import ImageFile
    from aisdb.weather.data_store import WeatherDataStore # for weather
    
    nest_asyncio.apply()
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    # >>> PostgreSQL Information <<<
    db_user=''            # DB User
    db_dbname='aisviz'         # DB Schema
    db_password=''    # DB Password
    db_hostaddr='127.0.0.1'    # DB Host address
    
    dbconn = PostgresDBConn(
        port=5555,             # PostgreSQL port
        user=db_user,          # PostgreSQL username
        dbname=db_dbname,      # PostgreSQL database
        host=db_hostaddr,      # PostgreSQL address
        password=db_password,  # PostgreSQL password
    )
    xmin, ymin, xmax, ymax = -70, 45, -58, 53
    gulf_bbox = [xmin, xmax, ymin, ymax]
    start_time = datetime(2023, 8, 1)
    end_time = datetime(2023, 8, 30)
    
    qry = DBQuery(
        dbconn=dbconn,
        start=start_time, end=end_time,
        xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax,
        callback=aisdb.database.sqlfcn_callbacks.in_time_bbox_validmmsi
    )
    
    ais_tracks = []
    rowgen = qry.gen_qry()
    tracks = aisdb.track_gen.TrackGen(rowgen, decimate=True)
    weather_data_store = WeatherDataStore(
        short_names=['10u', '10v', 'tp'],       # U & V wind components, total precipitation
        start=start_time,                      # e.g., datetime(2023, 8, 1)
        end=end_time,                          # e.g., datetime(2023, 8, 30)
        weather_data_path=".",                 # Local folder to store downloaded GRIBs
        download_from_cds=True,                # Enable download from CDS
        area=[-70, 45, -58, 53]                # bounding box in degrees (same xmin, ymin, xmax, ymax values as defined above)
    )
    tracks_with_weather = weather_data_store.yield_tracks_with_weather(tracks)
        
    for track in tracks_with_weather:
        print(f"'u-component' 10m wind for:\nlat: {track['lat'][0]} \nlon: {track['lon'][0]} \ntime: {track['time'][0]} \nis {track['weather_data']['u10'][0]} m/s")
        break
        
    weather_data_store.close() # gracefully closes the opened GRIB file.
    def qry_database(dbname, start_time, stop_time):
        d_threshold = 200000  # max distance (in meters) between two messages before assuming separate tracks
        s_threshold = 50  # max speed (in knots) between two AIS messages before splitting tracks
        t_threshold = timedelta(hours=24)  # max time (in hours) between messages of a track
        try:
            with aisdb.DBConn() as dbconn:
                tracks = aisdb.TrackGen(
                    aisdb.DBQuery(
                        dbconn=dbconn, dbpath=os.path.join(ROOT, PATH, dbname),
                        callback=aisdb.database.sql_query_strings.in_timerange,
                        start=start_time, end=stop_time).gen_qry(),
                        decimate=False)  # trajectory compression
                tracks = aisdb.split_timedelta(tracks, t_threshold)  # split trajectories by time without AIS message transmission
                tracks = aisdb.encode_greatcircledistance(tracks, distance_threshold=d_threshold, speed_threshold=s_threshold)
                tracks = aisdb.interp_time(tracks, step=timedelta(minutes=5))  # interpolate every n-minutes
                # tracks = vessel_info(tracks, dbconn=dbconn)  # scrapes vessel metadata
                return list(tracks)  # list of segmented pre-processed tracks
        except SyntaxError as e: return []  # no results for query
    def get_tracks(dbname, start_ddmmyyyy, stop_ddmmyyyy):
        stop_time = datetime.strptime(stop_ddmmyyyy, "%d%m%Y")
        start_time = datetime.strptime(start_ddmmyyyy, "%d%m%Y")
        # returns a list with all tracks from AISdb
        return qry_database(dbname, start_time, stop_time)
    def batch_tracks(dbname, start_ddmmyyyy, stop_ddmmyyyy, hours2batch):
        stop_time = datetime.strptime(stop_ddmmyyyy, "%d%m%Y")
        start_time = datetime.strptime(start_ddmmyyyy, "%d%m%Y")
        # yields a list of results every delta_time iterations
        delta_time = timedelta(hours=hours2batch)
        anchor_time, next_time = start_time, start_time + delta_time
        while next_time < stop_time:
            yield qry_database(dbname, anchor_time, next_time)
            anchor_time = next_time
            next_time += delta_time
        # yields a list of final results (if any)
        yield qry_database(dbname, anchor_time, stop_time)
    # Create the map with specific latitude and longitude limits and a Mercator projection
    m = Basemap(llcrnrlat=42, urcrnrlat=52, llcrnrlon=-70, urcrnrlon=-50, projection="merc", resolution="h")
    
    # Draw state, country, coastline borders, and counties
    m.drawstates(0.5)
    m.drawcountries(0.5)
    m.drawcoastlines(0.5)
    m.drawcounties(color="gray", linewidth=0.5)
    
    # Fill continents and oceans
    m.fillcontinents(color="tan", lake_color="#91A3B0")
    m.drawmapboundary(fill_color="#91A3B0")
    
    coordinates = [
        (51.26912, -57.53759), (48.92733, -58.87786),
        (47.49307, -59.41325), (42.54760, -62.17624),
        (43.21702, -60.49943), (44.14955, -60.59600),
        (45.42599, -59.76398), (46.99134, -60.02403)]
    
    # Draw 100km-radius circles
    for lat, lon in coordinates:
        radius_in_degrees = 100 / (111.32 * np.cos(np.deg2rad(lat)))
        m.tissot(lon, lat, radius_in_degrees, 100, facecolor="r", edgecolor="k", alpha=0.5)
    
    # Add text annotation with an arrow pointing to the circle
    plt.annotate("AIS Coverage", xy=m(lon, lat), xytext=(40, -40),
                 textcoords="offset points", ha="left", va="bottom", fontweight="bold",
                 arrowprops=dict(arrowstyle="->", color="k", alpha=0.7, lw=2.5))
    
    # Add labels
    ocean_labels = {
        "Atlantic Ocean": [(-59, 44), 16],
        "Gulf of\nMaine": [(-67, 44.5), 12],
        "Gulf of St. Lawrence": [(-64.5, 48.5), 11],
    }
    
    for label, (coords, fontsize) in ocean_labels.items():
        plt.annotate(label, xy=m(*coords), xytext=(6, 6), textcoords="offset points",
                     fontsize=fontsize, color="#DBE2E9", fontweight="bold")
    
    # Add a scale in kilometers
    m.drawmapscale(-67.5, 42.7, -67.5, 42.7, 500, barstyle="fancy", fontsize=8, units="km", labelstyle="simple")
    
    # Set the map title
    _ = plt.title("100km-AIS radius-coverage on Atlantic Canada", fontweight="bold")
    # The circle diameter is 200km, and it does not match the km scale (approximation)
    land_polygons = gpd.read_file(os.path.join(ROOT, SHAPES, "ne_50m_land.shp"))
    def is_on_land(lat, lon, land_polygons):
        return land_polygons.contains(Point(lon, lat)).any()
    def is_track_on_land(track, land_polygons):
        for lat, lon in zip(track["lat"], track["lon"]):
            if is_on_land(lat, lon, land_polygons):
                return True
        return False
    def process_mmsi(item, polygons):
        mmsi, tracks = item
        filtered_tracks = [t for t in tracks if not is_track_on_land(t, polygons)]
        return mmsi, filtered_tracks, len(tracks)
    def process_voyages(voyages, land_polygons):
    
        # Tracking progress with TQDM
        def process_mmsi_callback(result, progress_bar):
            mmsi, filtered_tracks, _ = result
            voyages[mmsi] = filtered_tracks
            progress_bar.update(1)
    
        # Initialize the progress bar with the total number of MMSIs
        progress_bar = tqdm(total=len(voyages), desc="MMSIs processed")
    
        with ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
            # Submit all MMSIs for processing
            futures = {executor.submit(process_mmsi, item, land_polygons): item for item in voyages.items()}
    
            # Retrieve the results as they become available and update the Voyages dictionary
            for future in as_completed(futures):
                result = future.result()
                process_mmsi_callback(result, progress_bar)
    
        # Close the progress bar after processing the complete
        progress_bar.close()
        return voyages
    file_name = "curated-ais.pkl"
    full_path = os.path.join(ROOT, ESRF, file_name)
    if not os.path.exists(full_path):
        voyages = process_voyages(voyages, land_polygons)
        pkl.dump(voyages, open(full_path, "wb"))
    else: voyages = pkl.load(open(full_path, "rb"))
    voyages_counts = {k: len(voyages[k]) for k in voyages.keys()}
    def plot_voyage_segments_distribution(voyages_counts, bar_color="#ba1644"):
        data = pd.DataFrame({"Segments": list(voyages_counts.values())})
        return  alt.Chart(data).mark_bar(color=bar_color).encode(
                    alt.X("Segments:Q", bin=alt.Bin(maxbins=90), title="Segments"),
                    alt.Y("count(Segments):Q", title="Count", scale=alt.Scale(type="log")))\
                .properties(title="Distribution of Voyage Segments", width=600, height=400)\
                .configure_axisX(titleFontSize=16).configure_axisY(titleFontSize=16)\
                .configure_title(fontSize=18).configure_view(strokeOpacity=0)
    alt.data_transformers.enable("default", max_rows=None)
    plot_voyage_segments_distribution(voyages_counts).display()
    long_term_voyages, short_term_voyages = [], []
    # Separating voyages
    for k in voyages_counts:
        if voyages_counts[k] < 30:
            short_term_voyages.append(k)
        else: long_term_voyages.append(k)
    # Shuffling for random distribution
    random.shuffle(short_term_voyages)
    random.shuffle(long_term_voyages)
    train_voyage, test_voyage = {}, {}
    # Iterate over short-term voyages:
    for i, k in enumerate(short_term_voyages):
        if i < int(0.8 * len(short_term_voyages)):
            train_voyage[k] = voyages[k]
        else: test_voyage[k] = voyages[k]
    # Iterate over long-term voyages:
    for i, k in enumerate(long_term_voyages):
        if i < int(0.8 * len(long_term_voyages)):
            train_voyage[k] = voyages[k]
        else: test_voyage[k] = voyages[k]
    def plot_voyage_length_distribution(data, title, bar_color, min_time=144, force_print=True):
        total_time = []
        for key in data.keys():
            for track in data[key]:
                if len(track["time"]) > min_time or force_print:
                    total_time.append(len(track["time"]))
        plot_data = pd.DataFrame({'Length': total_time})
        chart = alt.Chart(plot_data).mark_bar(color=bar_color).encode(
            alt.Y("count(Length):Q", title="Count", scale=alt.Scale(type="symlog")),
            alt.X("Length:Q", bin=alt.Bin(maxbins=90), title="Length")
        ).properties(title=title, width=600, height=400)\
        .configure_axisX(titleFontSize=16).configure_axisY(titleFontSize=16)\
        .configure_title(fontSize=18).configure_view(strokeOpacity=0)
        print("\n\n")
        return chart
    display(plot_voyage_length_distribution(train_voyage, "TRAINING: Distribution of Voyage Length", "#287561"))
    display(plot_voyage_length_distribution(test_voyage, "TEST: Distribution of Voyage Length", "#3e57ab"))
    INPUT_TIMESTEPS = 48  # 4 hours * 12 AIS messages/h
    INPUT_VARIABLES = 4  # Longitude, Latitude, COG, and SOG
    OUTPUT_TIMESTEPS = 96  # 8 hours * 12 AIS messages/h
    OUTPUT_VARIABLES = 2  # Longitude and Latitude
    NUM_WORKERS = multiprocessing.cpu_count()
    INPUT_VARIABLES *= 2  # Double the features with deltas
    def filter_and_transform_voyages(voyages):
        filtered_voyages = {}
        for k, v in voyages.items():
            voyages_track = []
            for voyage in v:
                if len(voyage["time"]) > (INPUT_TIMESTEPS + OUTPUT_TIMESTEPS):
                    mtx = np.vstack([voyage["lon"], voyage["lat"],
                                     voyage["cog"], voyage["sog"]]).T
                    # Compute deltas
                    deltas = np.diff(mtx, axis=0)
                    # Add zeros at the first row for deltas
                    deltas = np.vstack([np.zeros(deltas.shape[1]), deltas])
                    # Concatenate the original matrix with the deltas matrix
                    mtx = np.hstack([mtx, deltas])
                    voyages_track.append(mtx)
            if len(voyages_track) > 0:
                filtered_voyages[k] = voyages_track
        return filtered_voyages
    # Checking how the data behaves for the previously set hyperparameters
    display(plot_voyage_length_distribution(train_voyage, "TRAINING: Distribution of Voyage Length", "#287561",
                                            min_time=INPUT_TIMESTEPS + OUTPUT_TIMESTEPS, force_print=False))
    display(plot_voyage_length_distribution(test_voyage, "TEST: Distribution of Voyage Length", "#3e57ab",
                                            min_time=INPUT_TIMESTEPS + OUTPUT_TIMESTEPS, force_print=False))
    # Filter and transform train and test voyages and prepare for training
    train_voyage = filter_and_transform_voyages(train_voyage)
    test_voyage = filter_and_transform_voyages(test_voyage)
    def print_voyage_statistics(header, voyage_dict):
        total_time = 0
        for mmsi, trajectories in voyage_dict.items():
            for trajectory in trajectories:
                total_time += trajectory.shape[0]
        print(f"{header}")
        print(f"Hours of sequential data: {total_time // 12}.")
        print(f"Number of unique MMSIs: {len(voyage_dict)}.", end=" \n\n")
        return total_time
    time_test = print_voyage_statistics("[TEST DATA]", test_voyage)
    time_train = print_voyage_statistics("[TRAINING DATA]", train_voyage)
    # We are left with a data distribution that still resembles the 80-20 split
    print(f"Training hourly-rate: {(time_train * 100) / (time_train + time_test)}%")
    print(f"Test hourly-rate: {(time_test * 100) / (time_train + time_test)}%")
    def haversine_distance(lon_1, lat_1, lon_2, lat_2):
        lon_1, lat_1, lon_2, lat_2 = map(np.radians, [lon_1, lat_1, lon_2, lat_2])  # convert latitude and longitude to radians
        a = np.sin((lat_2 - lat_1) / 2) ** 2 + np.cos(lat_1) * np.cos(lat_2) * np.sin((lon_2 - lon_1) / 2) ** 2
        return (2 * np.arcsin(np.sqrt(a))) * 6371000  # R: 6,371,000 meters
    def trajectory_straightness(x):
        # Ratio of straight-line distance to cumulative path length (close to 1 for a straight track)
        start_point, end_point = x[0, :2], x[-1, :2]
        x_coordinates, y_coordinates = x[:-1, 0], x[:-1, 1]
        x_coordinates_next, y_coordinates_next = x[1:, 0], x[1:, 1]
        consecutive_distances = np.array(haversine_distance(x_coordinates, y_coordinates, x_coordinates_next, y_coordinates_next))
        straight_line_distance = np.array(haversine_distance(start_point[0], start_point[1], end_point[0], end_point[1]))
        result = straight_line_distance / np.sum(consecutive_distances)
        return result if not np.isnan(result) else 1
    def process_voyage(voyage, mmsi, max_size, overlap_size=1):
        # Slide a one-step window over the voyage: each sample pairs an INPUT_TIMESTEPS-long input
        # window with the subsequent target coordinates; with overlap_size=1 the first target point
        # coincides with the last input point, so the decoder starts from a known position.
        straightness_ratios, mmsis, x, y = [], [], [], []
    
        for j in range(0, voyage.shape[0] - max_size, 1):
            x_sample = voyage[(0 + j):(INPUT_TIMESTEPS + j)]
            y_sample = voyage[(INPUT_TIMESTEPS + j - overlap_size):(max_size + j), 0:OUTPUT_VARIABLES]
            straightness = trajectory_straightness(x_sample)
            straightness_ratios.append(straightness)
            x.append(x_sample.T)
            y.append(y_sample.T)
            mmsis.append(mmsi)
    
        return straightness_ratios, mmsis, x, y
    def process_data(voyages):
        max_size = INPUT_TIMESTEPS + OUTPUT_TIMESTEPS
    
        # Callback function to update tqdm progress bar
        def process_voyage_callback(result, pbar):
            pbar.update(1)
            return result
    
        with Pool(NUM_WORKERS) as pool, tqdm(total=sum(len(v) for v in voyages.values()), desc="Voyages") as pbar:
            results = []
    
            # Submit tasks to the pool and store the results
            for mmsi in voyages:
                for voyage in voyages[mmsi]:
                    callback = partial(process_voyage_callback, pbar=pbar)
                    results.append(pool.apply_async(process_voyage, (voyage, mmsi, max_size), callback=callback))
    
            pool.close()
            pool.join()
    
            # Gather the results
            straightness_ratios, mmsis, x, y = [], [], [], []
            for result in results:
                s_ratios, s_mmsis, s_x, s_y = result.get()
                straightness_ratios.extend(s_ratios)
                mmsis.extend(s_mmsis)
                x.extend(s_x)
                y.extend(s_y)
    
        # Process the results
        x, y = np.stack(x), np.stack(y)
        x, y = np.transpose(x, (0, 2, 1)), np.transpose(y, (0, 2, 1))
        straightness_ratios = np.array(straightness_ratios)
        min_straightness, max_straightness = np.min(straightness_ratios), np.max(straightness_ratios)
        scaled_straightness_ratios = (straightness_ratios - min_straightness) / (max_straightness - min_straightness)
        scaled_straightness_ratios = 1. - scaled_straightness_ratios
    
        print(f"Final number of samples = {len(x)}", end="\n\n")
        return mmsis, x, y, scaled_straightness_ratios
    mmsi_train, x_train, y_train, straightness_ratios = process_data(train_voyage)
    mmsi_test, x_test, y_test, _ = process_data(test_voyage)
    def normalize_dataset(x_train, x_test, y_train,
                          lat_min=42, lat_max=52, lon_min=-70, lon_max=-50, max_sog=50):
    
        def normalize(arr, min_val, max_val):
            return (arr - min_val) / (max_val - min_val)
    
        # Initial normalization
        x_train[:, :, :2] = normalize(x_train[:, :, :2], np.array([lon_min, lat_min]), np.array([lon_max, lat_max]))
        y_train[:, :, :2] = normalize(y_train[:, :, :2], np.array([lon_min, lat_min]), np.array([lon_max, lat_max]))
        x_test[:, :, :2] = normalize(x_test[:, :, :2], np.array([lon_min, lat_min]), np.array([lon_max, lat_max]))
    
        x_train[:, :, 2:4] = x_train[:, :, 2:4] / np.array([360, max_sog])
        x_test[:, :, 2:4] = x_test[:, :, 2:4] / np.array([360, max_sog])
    
        # Standardize X and Y
        x_mean, x_std = np.mean(x_train, axis=(0, 1)), np.std(x_train, axis=(0, 1))
        y_mean, y_std = np.mean(y_train, axis=(0, 1)), np.std(y_train, axis=(0, 1))
    
        x_train = (x_train - x_mean) / x_std
        y_train = (y_train - y_mean) / y_std
        x_test = (x_test - x_mean) / x_std
    
        # Final zero-one normalization
        x_min, x_max = np.min(x_train, axis=(0, 1)), np.max(x_train, axis=(0, 1))
        y_min, y_max = np.min(y_train, axis=(0, 1)), np.max(y_train, axis=(0, 1))
    
        x_train = (x_train - x_min) / (x_max - x_min)
        y_train = (y_train - y_min) / (y_max - y_min)
        x_test = (x_test - x_min) / (x_max - x_min)
    
        return x_train, x_test, y_train, y_mean, y_std, y_min, y_max, x_mean, x_std, x_min, x_max
    
    x_train, x_test, y_train, y_mean, y_std, y_min, y_max, x_mean, x_std, x_min, x_max = normalize_dataset(x_train, x_test, y_train)
    def denormalize_y(y_data, y_mean, y_std, y_min, y_max,
                      lat_min=42, lat_max=52, lon_min=-70, lon_max=-50):
        y_data = y_data * (y_max - y_min) + y_min  # reverse zero-one normalization
        y_data = y_data * y_std + y_mean  # reverse standardization
        # Reverse initial normalization for longitude and latitude
        y_data[:, :, 0] = y_data[:, :, 0] * (lon_max - lon_min) + lon_min
        y_data[:, :, 1] = y_data[:, :, 1] * (lat_max - lat_min) + lat_min
        return y_data
    def denormalize_x(x_data, x_mean, x_std, x_min, x_max,
                      lat_min=42, lat_max=52, lon_min=-70, lon_max=-50):
        x_data = x_data * (x_max - x_min) + x_min  # reverse zero-one normalization
        x_data = x_data * x_std + x_mean  # reverse standardization
        # Reverse initial normalization for longitude and latitude
        x_data[:, :, 0] = x_data[:, :, 0] * (lon_max - lon_min) + lon_min
        x_data[:, :, 1] = x_data[:, :, 1] * (lat_max - lat_min) + lat_min
        return x_data
    tf.keras.backend.clear_session()  # Clear the Keras session to prevent potential conflicts
    _ = wandb.login(force=True)  # Log in to Weights & Biases
    class ProbabilisticTeacherForcing(Layer):
        def __init__(self, **kwargs):
            super(ProbabilisticTeacherForcing, self).__init__(**kwargs)
    
        def call(self, inputs):
            decoder_gt_input, decoder_output, mixing_prob = inputs
            mixing_prob = tf.expand_dims(mixing_prob, axis=-1)  # Add an extra dimension for broadcasting
            mixing_prob = tf.broadcast_to(mixing_prob, tf.shape(decoder_gt_input))  # Broadcast to match the shape
            return tf.where(tf.random.uniform(tf.shape(decoder_gt_input)) < mixing_prob, decoder_gt_input, decoder_output)
    def build_model(rnn_unit="GRU", hidden_size=64):
        encoder_input = Input(shape=(INPUT_TIMESTEPS, INPUT_VARIABLES), name="Encoder_Input")
        decoder_gt_input = Input(shape=((OUTPUT_TIMESTEPS - 1), OUTPUT_VARIABLES), name="Decoder-GT-Input")
        mixing_prob_input = Input(shape=(1,), name="Mixing_Probability")
    
        # Encoder
        encoder_gru = eval(rnn_unit)(hidden_size, activation="relu", name="Encoder")(encoder_input)
        repeat_vector = RepeatVector((OUTPUT_TIMESTEPS - 1), name="Repeater")(encoder_gru)
    
        # Inference Decoder
        decoder_gru = eval(rnn_unit)(hidden_size, activation="relu", return_sequences=True, name="Decoder")
        decoder_output = decoder_gru(repeat_vector, initial_state=encoder_gru)
    
        # Adjust decoder_output shape
        dense_output_adjust = TimeDistributed(Dense(OUTPUT_VARIABLES), name="Output_Adjust")
        adjusted_decoder_output = dense_output_adjust(decoder_output)
    
        # Training Decoder
        decoder_gru_tf = eval(rnn_unit)(hidden_size, activation="relu", return_sequences=True, name="Decoder-TF")
        probabilistic_tf_layer = ProbabilisticTeacherForcing(name="Probabilistic_Teacher_Forcing")
        mixed_input = probabilistic_tf_layer([decoder_gt_input, adjusted_decoder_output, mixing_prob_input])
        tf_output = decoder_gru_tf(mixed_input, initial_state=encoder_gru)
        tf_output = dense_output_adjust(tf_output)  # Use dense_output_adjust layer for training output
    
        training_model = Model(inputs=[encoder_input, decoder_gt_input, mixing_prob_input], outputs=tf_output, name="Training")
        inference_model = Model(inputs=encoder_input, outputs=adjusted_decoder_output, name="Inference")
    
        return training_model, inference_model
    training_model, model = build_model()
    def denormalize_y(y_data, y_mean, y_std, y_min, y_max, lat_min=42, lat_max=52, lon_min=-70, lon_max=-50):
        scales = tf.constant([lon_max - lon_min, lat_max - lat_min], dtype=tf.float32)
        biases = tf.constant([lon_min, lat_min], dtype=tf.float32)
    
        # Reverse zero-one normalization and standardization
        y_data = y_data * (y_max - y_min) + y_min
        y_data = y_data * y_std + y_mean
    
        # Reverse initial normalization for longitude and latitude
        return y_data * scales + biases
    def haversine_distance(lon1, lat1, lon2, lat2):
        lon1, lat1, lon2, lat2 = [tf.math.multiply(x, tf.divide(tf.constant(np.pi), 180.)) for x in [lon1, lat1, lon2, lat2]]  # lat and lon to radians
        a = tf.math.square(tf.math.sin((lat2 - lat1) / 2.)) + tf.math.cos(lat1) * tf.math.cos(lat2) * tf.math.square(tf.math.sin((lon2 - lon1) / 2.))
        return 2 * 6371000 * tf.math.asin(tf.math.sqrt(a))  # The earth Radius is 6,371,000 meters
    def custom_loss(y_true, y_pred):
        tf.debugging.check_numerics(y_true, "y_true contains NaNs")
        tf.debugging.check_numerics(y_pred, "y_pred contains NaNs")
    
        # Denormalize true and predicted y
        y_true_denorm = denormalize_y(y_true, y_mean, y_std, y_min, y_max)
        y_pred_denorm = denormalize_y(y_pred, y_mean, y_std, y_min, y_max)
    
        # Compute haversine distance for true and predicted y from the second time-step
        true_dist = haversine_distance(y_true_denorm[:, 1:, 0], y_true_denorm[:, 1:, 1], y_true_denorm[:, :-1, 0], y_true_denorm[:, :-1, 1])
        pred_dist = haversine_distance(y_pred_denorm[:, 1:, 0], y_pred_denorm[:, 1:, 1], y_pred_denorm[:, :-1, 0], y_pred_denorm[:, :-1, 1])
    
        # Convert maximum speed from knots to meters per 5 minutes
        max_speed_m_per_5min = 50 * 1.852 * 1000 * 5 / 60
    
        # Compute the difference in distances
        dist_diff = tf.abs(true_dist - pred_dist)
    
        # Apply penalty if the predicted distance is greater than the maximum possible distance
        dist_diff = tf.where(pred_dist > max_speed_m_per_5min, pred_dist - max_speed_m_per_5min, dist_diff)
    
        # Penalty for the first output coordinate not being the same as the last input
        input_output_diff = haversine_distance(y_true_denorm[:, 0, 0], y_true_denorm[:, 0, 1], y_pred_denorm[:, 0, 0], y_pred_denorm[:, 0, 1])
    
        # Compute RMSE excluding the first element
        rmse = K.sqrt(K.mean(K.square(y_true_denorm[:, 1:, :] - y_pred_denorm[:, 1:, :]), axis=1))
    
        tf.debugging.check_numerics(y_true_denorm, "y_true_denorm contains NaNs")
        tf.debugging.check_numerics(y_pred_denorm, "y_pred_denorm contains NaNs")
        tf.debugging.check_numerics(true_dist, "true_dist contains NaNs")
        tf.debugging.check_numerics(pred_dist, "pred_dist contains NaNs")
        tf.debugging.check_numerics(dist_diff, "dist_diff contains NaNs")
        tf.debugging.check_numerics(input_output_diff, "input_output_diff contains NaNs")
        tf.debugging.check_numerics(rmse, "rmse contains NaNs")
    
        # Final loss with weights
        # return 0.25 * K.mean(input_output_diff) + 0.35 * K.mean(dist_diff) + 0.40 * K.mean(rmse)
        return K.mean(rmse)
    def compile_model(model, learning_rate, clipnorm, jit_compile, skip_summary=False):
        optimizer = AdamW(learning_rate=learning_rate, clipnorm=clipnorm, jit_compile=jit_compile)
        model.compile(optimizer=optimizer, loss=custom_loss, metrics=["mae", "mape"], weighted_metrics=[], jit_compile=jit_compile)
        if not skip_summary: model.summary()  # print a summary of the model architecture
    compile_model(training_model, learning_rate=0.001, clipnorm=1, jit_compile=True)
    compile_model(model, learning_rate=0.001, clipnorm=1, jit_compile=True)
    Model: "Training"
    __________________________________________________________________________________________________
     Layer (type)                   Output Shape         Param #     Connected to                     
    ==================================================================================================
     Encoder_Input (InputLayer)     [(None, 48, 8)]      0           []                               
                                                                                                      
     Encoder (GRU)                  (None, 64)           14208       ['Encoder_Input[0][0]']          
                                                                                                      
     Repeater (RepeatVector)        (None, 95, 64)       0           ['Encoder[0][0]']                
                                                                                                      
     Decoder (GRU)                  (None, 95, 64)       24960       ['Repeater[0][0]',               
                                                                      'Encoder[0][0]']                
                                                                                                      
     Output_Adjust (TimeDistributed  (None, 95, 2)       130         ['Decoder[0][0]',                
     )                                                                'Decoder-TF[0][0]']             
                                                                                                      
     Decoder-GT-Input (InputLayer)  [(None, 95, 2)]      0           []                               
                                                                                                      
     Mixing_Probability (InputLayer  [(None, 1)]         0           []                               
     )                                                                                                
                                                                                                      
     Probabilistic_Teacher_Forcing   (None, 95, 2)       0           ['Decoder-GT-Input[0][0]',       
     (ProbabilisticTeacherForcing)                                    'Output_Adjust[0][0]',          
                                                                      'Mixing_Probability[0][0]']     
                                                                                                      
     Decoder-TF (GRU)               (None, 95, 64)       13056       ['Probabilistic_Teacher_Forcing[0
                                                                     ][0]',                           
                                                                      'Encoder[0][0]']                
                                                                                                      
    ==================================================================================================
    Total params: 52,354
    Trainable params: 52,354
    Non-trainable params: 0
    __________________________________________________________________________________________________
    Model: "Inference"
    __________________________________________________________________________________________________
     Layer (type)                   Output Shape         Param #     Connected to                     
    ==================================================================================================
     Encoder_Input (InputLayer)     [(None, 48, 8)]      0           []                               
                                                                                                      
     Encoder (GRU)                  (None, 64)           14208       ['Encoder_Input[0][0]']          
                                                                                                      
     Repeater (RepeatVector)        (None, 95, 64)       0           ['Encoder[0][0]']                
                                                                                                      
     Decoder (GRU)                  (None, 95, 64)       24960       ['Repeater[0][0]',               
                                                                      'Encoder[0][0]']                
                                                                                                      
     Output_Adjust (TimeDistributed  (None, 95, 2)       130         ['Decoder[0][0]']                
     )                                                                                                
                                                                                                      
    ==================================================================================================
    Total params: 39,298
    Trainable params: 39,298
    Non-trainable params: 0
    __________________________________________________________________________________________________
    def create_callbacks(model_name, monitor="val_loss", factor=0.2, lr_patience=3, ep_patience=12, min_lr=0, verbose=0, restore_best_weights=True, skip_wandb=False):
        return  ([wandb.keras.WandbMetricsLogger()] if not skip_wandb else []) + [#tf.keras.callbacks.TerminateOnNaN(),
                ReduceLROnPlateau(monitor=monitor, factor=factor, patience=lr_patience, min_lr=min_lr, verbose=verbose),
                EarlyStopping(monitor=monitor, patience=ep_patience, verbose=verbose, restore_best_weights=restore_best_weights),
                tf.keras.callbacks.ModelCheckpoint(os.path.join(ROOT, MODELS, model_name), monitor="val_loss", mode="min", save_best_only=True, verbose=verbose)]
    def train_model(model, x_train, y_train, batch_size, epochs, validation_split, model_name):
        run = wandb.init(project="kAISdb", anonymous="allow")  # start the wandb run
    
        # Set the initial mixing probability
        mixing_prob = 0.5
    
        # Update y_train to have the same dimensions as the output
        y_train = y_train[:, :(OUTPUT_TIMESTEPS - 1), :]
    
        # Create the ground truth input for the decoder by appending a padding at the beginning of the sequence
        decoder_ground_truth_input_data = (np.zeros((y_train.shape[0], 1, y_train.shape[2])), y_train[:, :-1, :])
        decoder_ground_truth_input_data = np.concatenate(decoder_ground_truth_input_data, axis=1)
    
        try:
            # Train the model with Teacher Forcing
            with tf.device(tf.test.gpu_device_name()):
                training_model.fit([x_train, decoder_ground_truth_input_data, np.full((x_train.shape[0], 1), mixing_prob)], y_train, batch_size=batch_size, epochs=epochs,
                                verbose=2, validation_split=validation_split, callbacks=create_callbacks(model_name))
                # , sample_weight=straightness_ratios)
        except KeyboardInterrupt as e:
            print("\nRestoring best weights [...]")
            # Load the weights of the teacher-forcing model
            training_model.load_weights(model_name)
            # Transfering the weights to the inference model
            for layer in model.layers:
                if layer.name in [l.name for l in training_model.layers]:
                    layer.set_weights(training_model.get_layer(layer.name).get_weights())
    
        run.finish()  # finish the wandb run
    model_name = "TF-GRU-AE.h5"
    full_path = os.path.join(ROOT, MODELS, model_name)
    if True:#not os.path.exists(full_path):
        train_model(model, x_train, y_train, batch_size=1024,
                    epochs=250, validation_split=0.2,
                    model_name=model_name)
    else:
        training_model.load_weights(full_path)
        for layer in model.layers:  # inference model initialization
            if layer.name in [l.name for l in training_model.layers]:
                layer.set_weights(training_model.get_layer(layer.name).get_weights())
    def evaluate_model(model, x_test, y_test, y_mean, y_std, y_min, y_max, y_pred=None):
    
        def single_trajectory_error(y_test, y_pred, index):
            distances = haversine_distance(y_test[index, :, 0], y_test[index, :, 1], y_pred[index, :, 0], y_pred[index, :, 1])
            return np.min(distances), np.max(distances), np.mean(distances), np.median(distances)
    
        # Modify this function to handle teacher-forced models with 95 output variables instead of 96
        def all_trajectory_error(y_test, y_pred):
            errors = [single_trajectory_error(y_test[:, 1:], y_pred, i) for i in range(y_test.shape[0])]
            min_errors, max_errors, mean_errors, median_errors = zip(*errors)
            return min(min_errors), max(max_errors), np.mean(mean_errors), np.median(median_errors)
    
        def plot_trajectory(x_test, y_test, y_pred, sample_index):
            min_error, max_error, mean_error, median_error = single_trajectory_error(y_test, y_pred, sample_index)
            fig = go.Figure()
    
            fig.add_trace(go.Scatter(x=x_test[sample_index, :, 0], y=x_test[sample_index, :, 1], mode="lines", name="Input Data", line=dict(color="green")))
            fig.add_trace(go.Scatter(x=y_test[sample_index, :, 0], y=y_test[sample_index, :, 1], mode="lines", name="Ground Truth", line=dict(color="blue")))
            fig.add_trace(go.Scatter(x=y_pred[sample_index, :, 0], y=y_pred[sample_index, :, 1], mode="lines", name="Forecasted Trajectory", line=dict(color="red")))
    
            fig.update_layout(title=f"Sample Index: {sample_index} | Distance Errors (in meters):<br>Min: {min_error:.2f}m, Max: {max_error:.2f}m, "
                                    f"Mean: {mean_error:.2f}m, Median: {median_error:.2f}m", xaxis_title="Longitude", yaxis_title="Latitude",
                              plot_bgcolor="#e4eaf0", paper_bgcolor="#fcfcfc", width=700, height=600)
    
            max_lon, max_lat = -58.705587131108196, 47.89066160591873
            min_lon, min_lat = -61.34247286889181, 46.09201839408127
    
            fig.update_xaxes(range=[min_lon, max_lon])
            fig.update_yaxes(range=[min_lat, max_lat])
    
            return fig
    
        if y_pred is None:
            with tf.device(tf.test.gpu_device_name()):
                y_pred = model.predict(x_test, verbose=0)
        y_pred_o = y_pred  # preserve the result
    
        x_test = denormalize_x(x_test, x_mean, x_std, x_min, x_max)
        y_pred = denormalize_y(y_pred_o, y_mean, y_std, y_min, y_max)
    
        # Modify this line to handle teacher-forced models with 95 output variables instead of 96
        for sample_index in [1000, 2500, 5000, 7500]:
            display(plot_trajectory(x_test, y_test[:, 1:], y_pred, sample_index))
    
        # The metrics require a lower dimension (no impact on the results)
        y_test_reshaped = np.reshape(y_test[:, 1:], (-1, y_test.shape[2]))
        y_pred_reshaped = np.reshape(y_pred, (-1, y_pred.shape[2]))
    
        # Physical Distance Error given in meters
        all_min_error, all_max_error, all_mean_error, all_median_error = all_trajectory_error(y_test, y_pred)
    
        print("\nAll Trajectories Min DE: {:.4f}m".format(all_min_error))
        print("All Trajectories Max DE: {:.4f}m".format(all_max_error))
        print("All Trajectories Mean DE: {:.4f}m".format(all_mean_error))
        print("All Trajectories Median DE: {:.4f}m".format(all_median_error))
    
        # Calculate evaluation metrics on the test data
        r2 = r2_score(y_test_reshaped, y_pred_reshaped)
        mse = mean_squared_error(y_test_reshaped, y_pred_reshaped)
        mae = mean_absolute_error(y_test_reshaped, y_pred_reshaped)
        evs = explained_variance_score(y_test_reshaped, y_pred_reshaped)
        mape = mean_absolute_percentage_error(y_test_reshaped, y_pred_reshaped)
        rmse = np.sqrt(mse)
    
        print(f"\nTest R^2: {r2:.4f}")
        print(f"Test MAE: {mae:.4f}")
        print(f"Test MSE: {mse:.4f}")
        print(f"Test RMSE: {rmse:.4f}")
        print(f"Test MAPE: {mape:.4f}")
        print(f"Test Explained Variance Score: {evs:.4f}")
    
        return y_pred_o
    _ = evaluate_model(model, x_test, y_test, y_mean, y_std, y_min, y_max)
    def save_history(history, model_name):
        history_name = model_name.replace('.h5', '.pkl')
        history_name = os.path.join(ROOT, MODELS, history_name)
        with open(history_name, 'wb') as f:
            pkl.dump(history, f)
    def load_history(model_name):
        history_name = model_name.replace('.h5', '.pkl')
        history_name = os.path.join(ROOT, MODELS, history_name)
        with open(history_name, 'rb') as f:
            history = pkl.load(f)
        return history
    def build_model(rnn_unit="GRU", enc_units_1=64, dec_units_1=64):
        encoder_input = Input(shape=(INPUT_TIMESTEPS, INPUT_VARIABLES), name="Encoder_Input")
        decoder_gt_input = Input(shape=((OUTPUT_TIMESTEPS - 1), OUTPUT_VARIABLES), name="Decoder-GT-Input")
        mixing_prob_input = Input(shape=(1,), name="Mixing_Probability")
    
        # Encoder
        encoder_gru = eval(rnn_unit)(enc_units_1, activation="relu", name="Encoder")(encoder_input)
        repeat_vector = RepeatVector((OUTPUT_TIMESTEPS - 1), name="Repeater")(encoder_gru)
    
        # Inference Decoder
        decoder_gru = eval(rnn_unit)(dec_units_1, activation="relu", return_sequences=True, name="Decoder")
        decoder_output = decoder_gru(repeat_vector, initial_state=encoder_gru)
    
        # Adjust decoder_output shape
        dense_output_adjust = TimeDistributed(Dense(OUTPUT_VARIABLES), name="Output_Adjust")
        adjusted_decoder_output = dense_output_adjust(decoder_output)
    
        # Training Decoder
        decoder_gru_tf = eval(rnn_unit)(dec_units_1, activation="relu", return_sequences=True, name="Decoder-TF")
        probabilistic_tf_layer = ProbabilisticTeacherForcing(name="Probabilistic_Teacher_Forcing")
        mixed_input = probabilistic_tf_layer([decoder_gt_input, adjusted_decoder_output, mixing_prob_input])
        tf_output = decoder_gru_tf(mixed_input, initial_state=encoder_gru)
        tf_output = dense_output_adjust(tf_output)  # Use dense_output_adjust layer for training output
    
        training_model = Model(inputs=[encoder_input, decoder_gt_input, mixing_prob_input], outputs=tf_output, name="Training")
        inference_model = Model(inputs=encoder_input, outputs=adjusted_decoder_output, name="Inference")
    
        return training_model, inference_model
    def objective(hyperparams, x_train, y_train, straightness_ratios, model_prefix):
        # Get the best hyperparameters from the optimization results
        enc_units_1 = hyperparams["enc_units_1"]
        dec_units_1 = hyperparams["dec_units_1"]
        mixing_prob = hyperparams["mixing_prob"]
        lr = hyperparams["learning_rate"]
    
        # Create the model name using the best hyperparameters
        model_name = f"{model_prefix}-{enc_units_1}-{dec_units_1}-{mixing_prob}-{lr}.h5"
        full_path = os.path.join(ROOT, MODELS, model_name)  # best model full path
    
        # Check if the model results file with this name already exists
        if not os.path.exists(full_path.replace(".h5", ".pkl")):
            print(f"Saving under {model_name}.")
    
            # Define the model architecture
            training_model, _ = build_model(enc_units_1=enc_units_1, dec_units_1=dec_units_1)
            compile_model(training_model, learning_rate=lr, clipnorm=1, jit_compile=True, skip_summary=True)
    
            # Update y_train to have the same dimensions as the output
            y_train = y_train[:, :(OUTPUT_TIMESTEPS - 1), :]
    
            # Create the ground truth input for the decoder by appending a padding at the beginning of the sequence
            decoder_ground_truth_input_data = (np.zeros((y_train.shape[0], 1, y_train.shape[2])), y_train[:, :-1, :])
            decoder_ground_truth_input_data = np.concatenate(decoder_ground_truth_input_data, axis=1)
    
            # Train the model on the data, using GPU if available
            with tf.device(tf.test.gpu_device_name()):
                history = training_model.fit([x_train, decoder_ground_truth_input_data, np.full((x_train.shape[0], 1), mixing_prob)], y_train,
                                             batch_size=10240, epochs=250, validation_split=.2, verbose=0,
                                             workers=multiprocessing.cpu_count(), use_multiprocessing=True,
                                             callbacks=create_callbacks(model_name, skip_wandb=True))
                                             #, sample_weight=straightness_ratios)
                # Save the training history
                save_history(history.history, model_name)
    
            # Clear the session to release resources
            del training_model; tf.keras.backend.clear_session()
        else:
            print("Loading pre-trained weights.")
            history = load_history(model_name)
    
        if type(history) == dict:  # validation loss of the model
            return {"loss": history["val_loss"][-1], "status": STATUS_OK}
        else: return {"loss": history.history["val_loss"][-1], "status": STATUS_OK}
    def optimize_hyperparameters(max_evals, model_prefix, x_train, y_train, sample_size=5000):
    
        def build_space(n_min=2, n_steps=9):
            # Defining a custom 2^N range function
            n_range = lambda n_min, n_steps: np.array(
                [2**n for n in range(n_min, n_steps) if 2**n >= n_min])
    
            # Defining the unconstrained search space
            encoder_1_range = n_range(n_min, n_steps)
            decoder_1_range = n_range(n_min, n_steps)
            learning_rate_range = [.01, .001, .0001]
            mixing_prob_range = [.25, .5, .75]
    
            # Enforcing constraints on the search space
            enc_units_1 = np.random.choice(encoder_1_range)
            dec_units_1 = np.random.choice(decoder_1_range[np.where(decoder_1_range == enc_units_1)])
            learning_rate = np.random.choice(learning_rate_range)
            mixing_prob = np.random.choice(mixing_prob_range)
    
            # Returns a single element of the search space
            return dict(enc_units_1=enc_units_1, dec_units_1=dec_units_1, learning_rate=learning_rate, mixing_prob=mixing_prob)
    
        # Select the search space based on a pre-set sampled random space
        search_space = hp.choice("hyperparams", [build_space() for _ in range(sample_size)])
        trials = Trials()  # initialize Hyperopt trials
    
        # Define the objective function for Hyperopt
        fn = lambda hyperparams: objective(hyperparams, x_train, y_train, straightness_ratios, model_prefix)
    
        # Perform Hyperopt optimization and find the best hyperparameters
        best = fmin(fn=fn, space=search_space, algo=tpe.suggest, max_evals=max_evals, trials=trials)
        best_hyperparams = space_eval(search_space, best)
    
        # Get the best hyperparameters from the optimization results
        enc_units_1 = best_hyperparams["enc_units_1"]
        dec_units_1 = best_hyperparams["dec_units_1"]
        mixing_prob = best_hyperparams["mixing_prob"]
        lr = best_hyperparams["learning_rate"]
    
        # Create the model name using the best hyperparameters
        model_name = f"{model_prefix}-{enc_units_1}-{dec_units_1}-{mixing_prob}-{lr}.h5"
        full_path = os.path.join(ROOT, MODELS, model_name)  # best model full path
        t_model, i_model = build_model(enc_units_1=enc_units_1, dec_units_1=dec_units_1)
        t_model = tf.keras.models.load_model(full_path)
        for layer in i_model.layers:  # inference model initialization
            if layer.name in [l.name for l in t_model.layers]:
                layer.set_weights(t_model.get_layer(layer.name).get_weights())
    
        print(f"Best hyperparameters:")
        print(f"  Encoder units 1: {enc_units_1}")
        print(f"  Decoder units 1: {dec_units_1}")
        print(f"  Mixing proba.: {mixing_prob}")
        print(f"  Learning rate: {lr}")
        return i_model
    max_evals, model_prefix = 100, "TF-GRU"
    # best_model = optimize_hyperparameters(max_evals, model_prefix, x_train, y_train)
    # [NOTE] YOU CAN SKIP THIS STEP BY LOADING THE PRE-TRAINED WEIGHTS ON THE NEXT CELL.
    def find_best_model(root_folder, model_prefix):
        best_model_name, best_val_loss = None, float('inf')
    
        for f in os.listdir(root_folder):
            if (f.endswith(".h5") and f.startswith(model_prefix)):
                try:
                    history = load_history(f)
    
                    # Get the validation loss
                    if type(history) == dict:
                        val_loss = history["val_loss"][-1]
                    else: val_loss = history.history["val_loss"][-1]
    
                    # Storing the best model
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        best_model_name = f
                except: pass
    
        # Load the best model
        full_path = os.path.join(ROOT, MODELS, best_model_name)
        t_model, i_model = build_model(enc_units_1=int(best_model_name.split("-")[2]), dec_units_1=int(best_model_name.split("-")[3]))
        t_model = tf.keras.models.load_model(full_path, custom_objects={"ProbabilisticTeacherForcing": ProbabilisticTeacherForcing})
        for layer in i_model.layers:  # inference model initialization
            if layer.name in [l.name for l in t_model.layers]:
                layer.set_weights(t_model.get_layer(layer.name).get_weights())
        # Print summary of the best model
        print(f"Loss: {best_val_loss}")
        i_model.summary()
    
        return i_model
    best_model = find_best_model(os.path.join(ROOT, MODELS), model_prefix)
    _ = evaluate_model(best_model, x_test, y_test, y_mean, y_std, y_min, y_max)
    def permutation_feature_importance(model, x_test, y_test, metric):
    
        # Function to calculate permutation feature importance
        def PFI(model, x, y_true, metric):
            # Reshape the true values for easier comparison with predictions
            y_true = np.reshape(y_true, (-1, y_true.shape[2]))
    
            # Predict using the model and reshape the predicted values
            with tf.device(tf.test.gpu_device_name()):
                y_pred = model.predict(x, verbose=0)
            y_pred = np.reshape(y_pred, (-1, y_pred.shape[2]))
    
            # Calculate the baseline score using the given metric
            baseline_score = metric(y_true, y_pred)
    
            # Initialize an array for feature importances
            feature_importances = np.zeros(x.shape[2])
    
            # Calculate the importance for each feature
            for feature_idx in range(x.shape[2]):
                x_permuted = x.copy()
                x_permuted[:, :, feature_idx] = np.random.permutation(x[:, :, feature_idx])
    
                # Predict using the permuted input and reshape the predicted values
                with tf.device(tf.test.gpu_device_name()):
                    y_pred_permuted = model.predict(x_permuted, verbose=0)
                y_pred_permuted = np.reshape(y_pred_permuted, (-1, y_pred_permuted.shape[2]))
    
                # Calculate the score with permuted input
                permuted_score = metric(y_true, y_pred_permuted)
    
                # Compute the feature importance as the difference between permuted and baseline scores
                feature_importances[feature_idx] = permuted_score - baseline_score
    
            return feature_importances
    
        feature_importances = PFI(model, x_test, y_test, metric)
    
        # Prepare the data for plotting (require a dataframe)
        feature_names = ["Longitude", "Latitude", "COG", "SOG"]
        feature_importance_df = pd.DataFrame({"features": feature_names, "importance": feature_importances})
    
        # Create the bar plot with Altair
        bar_plot = alt.Chart(feature_importance_df).mark_bar(size=40, color="mediumblue", opacity=0.8).encode(
            x=alt.X("features:N", title="Features", axis=alt.Axis(labelFontSize=12, titleFontSize=14)),
            y=alt.Y("importance:Q", title="Permutation Importance", axis=alt.Axis(labelFontSize=12, titleFontSize=14)),
        ).properties(title=alt.TitleParams(text="Feature Importance", fontSize=16, fontWeight="bold"), width=400, height=300)
    
        return bar_plot, feature_importances
    permutation_feature_importance(best_model, x_test, y_test, mean_absolute_error)[0].display()
    def sensitivity_analysis(model, x_sample, perturbation_range=(-0.1, 0.1), num_steps=10, plot_nrows=4):
        # Get the number of features and outputs
        num_features = x_sample.shape[1]
        num_outputs = model.output_shape[-1] * model.output_shape[-2]
    
        # Create an array of perturbations
        perturbations = np.linspace(perturbation_range[0], perturbation_range[1], num_steps)
    
        # Initialize sensitivity array
        sensitivity = np.zeros((num_features, num_outputs, num_steps))
    
        # Get the original prediction for the input sample
        original_prediction = model.predict(x_sample.reshape(1, -1, 4), verbose=0).reshape(-1)
    
        # Iterate over input features and perturbations
        for feature_idx in range(num_features):
            for i, perturbation in enumerate(perturbations):
                # Create a perturbed version of the input sample
                perturbed_sample = x_sample.copy()
                perturbed_sample[:, feature_idx] += perturbation
    
                # Get the prediction for the perturbed input sample
                perturbed_prediction = model.predict(perturbed_sample.reshape(1, -1, 4), verbose=0).reshape(-1)
    
                # Calculate the absolute prediction change and store it in the sensitivity array
                sensitivity[feature_idx, :, i] = np.abs(perturbed_prediction - original_prediction)
    
        # Determine the number of rows and columns in the plot
        ncols = 6
        nrows = max(min(plot_nrows, math.ceil(num_outputs / ncols)), 1)
    
        # Define feature names
        feature_names = ["Longitude", "Latitude", "COG", "SOG"]
    
        # Create the sensitivity plot
        fig, axs = plt.subplots(nrows, ncols, figsize=(18, 3 * nrows), sharex=True, sharey=True)
        axs = axs.ravel()
    
        output_idx = 0
        for row in range(nrows):
            for col in range(ncols):
                if output_idx < num_outputs:
                    # Plot sensitivity curves for each feature
                    for feature_idx in range(num_features):
                        axs[output_idx].plot(perturbations, sensitivity[feature_idx, output_idx], label=f'{feature_names[feature_idx]}')
                    # Set the title for each subplot
                    axs[output_idx].set_title(f'Output {output_idx // 2 + 1}, {"Longitude" if output_idx % 2 == 0 else "Latitude"}')
                    output_idx += 1
    
        # Set common labels and legend
        fig.text(0.5, 0.04, 'Perturbation', ha='center', va='center')
        fig.text(0.06, 0.5, 'Absolute Prediction Change', ha='center', va='center', rotation='vertical')
        handles, labels = axs[0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='upper center', ncol=num_features, bbox_to_anchor=(.5, .87))
    
        plt.tight_layout()
        plt.subplots_adjust(top=0.8, bottom=0.1, left=0.1, right=0.9)
        plt.show()
    
        return sensitivity
    x_sample = x_test[100]  # Select a sample from the test set
    sensitivity = sensitivity_analysis(best_model, x_sample)
    def visualize_intermediate_representations(model, x_test_subset, y_test_subset, n_neighbors=15, min_dist=0.1, n_components=2):
        # Extract intermediate representations from your model
        intermediate_layer_model = keras.Model(inputs=model.input, outputs=model.layers[-2].output)
        intermediate_output = intermediate_layer_model.predict(x_test_subset, verbose=0)
    
        # Flatten the last two dimensions of the intermediate_output
        flat_intermediate_output = intermediate_output.reshape(intermediate_output.shape[0], -1)
    
        # UMAP
        reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=seed_value)
        umap_output = reducer.fit_transform(flat_intermediate_output)
    
        # Convert y_test_subset to strings
        y_test_str = np.array([str(label) for label in y_test_subset])
    
        # Map string labels to colors
        unique_labels = np.unique(y_test_str)
        colormap = plt.cm.get_cmap('viridis', len(unique_labels))
        label_to_color = {label: colormap(i) for i, label in enumerate(unique_labels)}
        colors = np.array([label_to_color[label] for label in y_test_str])
    
        # Create plot with Matplotlib
        fig, ax = plt.subplots(figsize=(10, 8))
        sc = ax.scatter(umap_output[:, 0], umap_output[:, 1], c=colors, s=5)
        ax.set_title("UMAP Visualization", fontsize=14, fontweight="bold")
        ax.set_xlabel("X Dimension", fontsize=12)
        ax.set_ylabel("Y Dimension", fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.5)
    
        # Add a colorbar to the plot
        sm = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin=0, vmax=len(unique_labels)-1))
        sm.set_array([])
        cbar = plt.colorbar(sm, ticks=range(len(unique_labels)), ax=ax)
        cbar.ax.set_yticklabels(unique_labels)
        cbar.set_label("MMSIs")
        plt.show()
    visualize_intermediate_representations(best_model, x_test[:10000], mmsi_test[:10000], n_neighbors=10, min_dist=0.5)
    import io
    import json
    import random
    from collections import defaultdict
    from datetime import datetime, timedelta
    
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import pyproj
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
    from torch.utils.data import DataLoader, TensorDataset
    
    import joblib
    import aisdb
    from aisdb import DBConn, database
    from aisdb.database import sqlfcn_callbacks
    from aisdb.database.sqlfcn_callbacks import in_timerange
    from aisdb.database.dbconn import PostgresDBConn
    from aisdb.track_gen import TrackGen
    
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    DB_CONNECTION = "/home/sqlite_database_file.db" # replace with your data path
    START_DATE = datetime(2018, 8, 1, hour=0)  # starting at midnight on 1st August 2018
    END_DATE = datetime(2018, 8, 4, hour=2)  # ending at 2:00 am on 4th August 2018
    XMIN, YMIN, XMAX, YMAX = -64.828126, 46.113933, -58.500001, 49.619290  # sample bounding box: x refers to longitude and y to latitude
    
    # database connection
    dbconn = DBConn(dbpath=DB_CONNECTION)
    
    # Generating query to extract data between a given time range
    qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
                          start = START_DATE,
                          end=END_DATE,
                          xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX,
                          )
    rowgen = qry.gen_qry(verbose=True) # generating query
    tracks = TrackGen(rowgen, decimate=False) # Convert rows into tracks
    # rowgen_ = qry.gen_qry(reaggregate_static=True, verbose=True) if you want metadata
    
    # To avoid evaluating on training data, choose test data from a completely different date range
    
    TEST_START_DATE = datetime(2018, 8, 5, hour=0)
    TEST_END_DATE = datetime(2018, 8, 6, hour=8)
    
    test_qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
                                 start = TEST_START_DATE,
                                 end = TEST_END_DATE,
                                 xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX
                            )
    test_tracks = TrackGen( (test_qry.gen_qry(verbose=True)), decimate=False)
    # --- Projection: Lat/Lon -> Cartesian (meters) ---
    proj = pyproj.Proj(proj='utm', zone=20, ellps='WGS84')
    def preprocess_aisdb_tracks(tracks_gen, proj, sog_scaler=None, feature_scaler=None, fit_scaler=False):
        # Keep as generator for AISdb functions
        tracks_gen = aisdb.remove_pings_wrt_speed(tracks_gen, 0.1)
        tracks_gen = aisdb.encode_greatcircledistance(tracks_gen,
                                                      distance_threshold=50000,
                                                      minscore=1e-5,
                                                      speed_threshold=50)
        tracks_gen = aisdb.interp_time(tracks_gen, step=timedelta(minutes=5))
    
        # Convert generator to list AFTER all AISdb steps
        tracks = list(tracks_gen)
    
        # Group by MMSI
        tracks_by_mmsi = defaultdict(list)
        for track in tracks:
            tracks_by_mmsi[track['mmsi']].append(track)
    
        # Keep only MMSIs with tracks >= 100 points
        valid_tracks = []
        for mmsi, mmsi_tracks in tracks_by_mmsi.items():
            if all(len(t['time']) >= 100 for t in mmsi_tracks):
                valid_tracks.extend(mmsi_tracks)
    
        # Convert to DataFrame
        rows = []
        for track in valid_tracks:
            mmsi = track['mmsi']
            sog = track.get('sog', [np.nan]*len(track['time']))
            cog = track.get('cog', [np.nan]*len(track['time']))
    
            for i in range(len(track['time'])):
                x, y = proj(track['lon'][i], track['lat'][i])
                cog_rad = np.radians(cog[i]) if cog[i] is not None else np.nan
                rows.append({
                    'mmsi': mmsi,
                    'x': x,
                    'y': y,
                    'sog': sog[i],
                    'cog_sin': np.sin(cog_rad) if not np.isnan(cog_rad) else np.nan,
                    'cog_cos': np.cos(cog_rad) if not np.isnan(cog_rad) else np.nan,
                    'timestamp': pd.to_datetime(track['time'][i], errors='coerce')
                })
    
        df = pd.DataFrame(rows)
    
        # Drop rows with NaNs
        df = df.replace([np.inf, -np.inf], np.nan)
        df = df.dropna(subset=['x', 'y', 'sog', 'cog_sin', 'cog_cos'])
    
        # Scale features
        feature_cols = ['x', 'y', 'sog', 'cog_sin', 'cog_cos']  # columns to scale
        if fit_scaler:
            sog_scaler = RobustScaler()
            df['sog_scaled'] = sog_scaler.fit_transform(df[['sog']])
            feature_scaler = RobustScaler()
            df[feature_cols] = feature_scaler.fit_transform(df[feature_cols])
        else:
            df['sog_scaled'] = sog_scaler.transform(df[['sog']])
            df[feature_cols] = feature_scaler.transform(df[feature_cols])
    
        return df, sog_scaler, feature_scaler
    train_qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
                              start=START_DATE, end=END_DATE,
                              xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX)
    train_gen = TrackGen(train_qry.gen_qry(verbose=True), decimate=False)
    
    test_qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
                             start=TEST_START_DATE, end=TEST_END_DATE,
                             xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX)
    test_gen = TrackGen(test_qry.gen_qry(verbose=True), decimate=False)
    
    # --- Preprocess ---
    train_df, sog_scaler, feature_scaler = preprocess_aisdb_tracks(train_gen, proj, fit_scaler=True)
    test_df, _, _ = preprocess_aisdb_tracks(test_gen, proj, sog_scaler=sog_scaler,
                                            feature_scaler=feature_scaler, fit_scaler=False)
    def create_sequences(df, features, input_size=80, output_size=2, step=1):
        X_list, Y_list = [], []
        for mmsi in df['mmsi'].unique():
            sub = df[df['mmsi']==mmsi].sort_values('timestamp')[features].to_numpy()
            for i in range(0, len(sub)-input_size-output_size+1, step):
                X_list.append(sub[i:i+input_size])
                Y_list.append(sub[i+input_size:i+input_size+output_size])
        return torch.tensor(X_list, dtype=torch.float32), torch.tensor(Y_list, dtype=torch.float32)
    
    features = ['x','y','cog_sin','cog_cos','sog_scaled']
    
    mmsis = train_df['mmsi'].unique()
    train_mmsi, val_mmsi = train_test_split(mmsis, test_size=0.2, random_state=42, shuffle=True)
    
    train_X, train_Y = create_sequences(train_df[train_df['mmsi'].isin(train_mmsi)], features)
    val_X, val_Y     = create_sequences(train_df[train_df['mmsi'].isin(val_mmsi)], features)
    test_X, test_Y   = create_sequences(test_df, features)
    # --- save datasets ---
    torch.save({
        'train_X': train_X, 'train_Y': train_Y,
        'val_X': val_X, 'val_Y': val_Y,
        'test_X': test_X, 'test_Y': test_Y
    }, 'datasets_seq2seq_cartesian5.pt')
    
    # --- save scalers ---
    joblib.dump(feature_scaler, "feature_scaler.pkl")
    joblib.dump(sog_scaler, "sog_scaler.pkl")
    
    # --- save projection parameters ---
    proj_params = {'proj': 'utm', 'zone': 20, 'ellps': 'WGS84'}
    with open("proj_params.json", "w") as f:
        json.dump(proj_params, f)
    
    data = torch.load('datasets_seq2seq_cartesian5.pt')
    
    # scalers
    feature_scaler = joblib.load("feature_scaler.pkl")
    sog_scaler = joblib.load("sog_scaler.pkl")
    
    # projection
    with open("proj_params.json", "r") as f:
        proj_params = json.load(f)
    proj = pyproj.Proj(**proj_params)
    
    train_ds = TensorDataset(data['train_X'], data['train_Y'])
    val_ds = TensorDataset(data['val_X'], data['val_Y'])
    test_ds = TensorDataset(data['test_X'], data['test_Y'])
    
    batch_size = 64
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_dl = DataLoader(test_ds, batch_size=batch_size)
    class Seq2SeqLSTM(nn.Module):
        def __init__(self, input_size, hidden_size, input_steps, output_steps):
            super().__init__()
            self.input_steps = input_steps
            self.output_steps = output_steps
    
            # Encoder
            self.encoder = nn.LSTM(input_size, hidden_size, num_layers=2, dropout=0.3, batch_first=True)
    
            # Decoder
            self.decoder = nn.LSTMCell(input_size, hidden_size)
            self.attn = nn.Linear(hidden_size + hidden_size, input_steps)  # basic Bahdanau-ish
            self.attn_combine = nn.Linear(hidden_size + input_size, input_size)
    
            # Output projection (predict residual proposal for all features)
            self.output_layer = nn.Sequential(
                nn.Linear(hidden_size, hidden_size // 2),
                nn.ReLU(),
                nn.Linear(hidden_size // 2, input_size)
            )
    
        def forward(self, x, target_seq=None, teacher_forcing_ratio=0.5):
            """
            x: [B, T_in, F]
            target_seq: [B, T_out, F] (optional, for teacher forcing)
            returns: [B, T_out, F] predicted absolute features (with x,y = last_obs + residuals)
            """
            batch_size = x.size(0)
            encoder_outputs, (h, c) = self.encoder(x)          # encoder_outputs: [B, T_in, H]
            h, c = h[-1], c[-1]                                # take final layer states -> [B, H]
    
            last_obs = x[:, -1, :]     # [B, F] last observed features (scaled)
            last_xy  = last_obs[:, :2] # [B, 2]
    
            decoder_input = last_obs   # start input is last observed feature vector
            outputs = []
    
            for t in range(self.output_steps):
                # attention weights over encoder outputs: combine h and c to get context
                attn_weights = torch.softmax(self.attn(torch.cat((h, c), dim=1)), dim=1)  # [B, T_in]
                context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # [B, H]
    
                dec_in = torch.cat((decoder_input, context), dim=1)  # [B, F+H]
                dec_in = self.attn_combine(dec_in)                   # [B, F]
    
                h, c = self.decoder(dec_in, (h, c))                  # h: [B, H]
    
                residual = self.output_layer(h)                      # [B, F] -- residual proposal for all features
    
                # Apply residual only to x,y (first two dims); other features are predicted directly
                out_xy = residual[:, :2] + last_xy                   # absolute x,y = last_xy + residual
                out_rest = residual[:, 2:]                           # predicted auxiliary features
                out = torch.cat([out_xy, out_rest], dim=1)           # [B, F]
    
                outputs.append(out.unsqueeze(1))                     # accumulate
    
                # Teacher forcing or autoregressive input
                if self.training and (target_seq is not None) and (t < target_seq.size(1)) and (random.random() < teacher_forcing_ratio):
                    decoder_input = target_seq[:, t, :]  # teacher forcing uses ground truth step (scaled)
                    # update last_xy to ground truth (so if teacher forced, we align residual base)
                    last_xy = decoder_input[:, :2]
                else:
                    decoder_input = out
                    last_xy = out[:, :2]                  # update last_xy to predicted absolute xy
    
            return torch.cat(outputs, dim=1)  # [B, T_out, F]
    
    def weighted_coord_loss(pred, target, coord_weight=5.0, reduction='mean'):
        """
        Give higher weight to coordinate errors (first two dims).
        pred/target: [B, T, F]
        """
        coord_loss = F.smooth_l1_loss(pred[..., :2], target[..., :2], reduction=reduction)
        aux_loss   = F.smooth_l1_loss(pred[..., 2:], target[..., 2:], reduction=reduction)
        return coord_weight * coord_loss + aux_loss
    
    
    def xy_smoothness_loss(seq_full):
        """
        Smoothness penalty only on x,y channels of a full sequence.
        seq_full: [B, T, F] (F >= 2)
        Returns scalar L2 of second differences on x,y.
        """
        xy = seq_full[..., :2]              # [B, T, 2]
        v = xy[:, 1:, :] - xy[:, :-1, :]    # velocity [B, T-1, 2]
        a = v[:, 1:, :] - v[:, :-1, :]      # acceleration [B, T-2, 2]
        return (v**2).mean() * 0.05 + (a**2).mean() * 0.5
    
    def train_model(model, loader, val_dl, optimizer, device, epochs=50,
                    coord_weight=5.0, smooth_w_init=1e-3):
        best_loss = float('inf')
        best_state = None
    
        for epoch in range(epochs):
            model.train()
            total_loss = total_data_loss = total_smooth_loss = 0.0
    
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(device)   # [B, T_in, F]
                batch_y = batch_y.to(device)   # [B, T_out, F]
    
                # --- Convert targets to residuals ---
                # residual = y[t] - y[t-1], relative to last obs (x[:,-1])
                y_start = batch_x[:, -1:, :2]   # last observed xy [B,1,2]
                residual_y = batch_y.clone()
                residual_y[..., :2] = batch_y[..., :2] - torch.cat([y_start, batch_y[:, :-1, :2]], dim=1)
    
                optimizer.zero_grad()
    
                # Model predicts residuals
                pred_residuals = model(batch_x, target_seq=residual_y, teacher_forcing_ratio=0.5)  # [B, T_out, F]
    
                # Data loss on residuals (only xy are residualized, other features can stay as-is)
                loss_data = weighted_coord_loss(pred_residuals, residual_y, coord_weight=coord_weight)
    
                # Smoothness: reconstruct absolute xy sequence first
                pred_xy = torch.cumsum(pred_residuals[..., :2], dim=1) + y_start
                full_seq = torch.cat([batch_x[..., :2], pred_xy], dim=1)   # [B, T_in+T_out, 2]
                loss_smooth = xy_smoothness_loss(full_seq)
    
                # Decay smooth weight slightly over epochs
                smooth_weight = smooth_w_init * max(0.1, 1.0 - epoch / 30.0)
                loss = loss_data + smooth_weight * loss_smooth
    
                # Stability guard
                if torch.isnan(loss) or torch.isinf(loss):
                    print("Skipping batch due to NaN/inf loss.")
                    continue
    
                loss.backward()
                optimizer.step()
    
                total_loss += loss.item()
                total_data_loss += loss_data.item()
                total_smooth_loss += loss_smooth.item()
    
            avg_loss = total_loss / len(loader)
            print(f"Epoch {epoch+1:02d} | Total: {avg_loss:.6f} | Data: {total_data_loss/len(loader):.6f} | Smooth: {total_smooth_loss/len(loader):.6f}")
    
            # Validation
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for xb, yb in val_dl:
                    xb, yb = xb.to(device), yb.to(device)
    
                    # Compute residuals for validation
                    y_start = xb[:, -1:, :2]
                    residual_y = yb.clone()
                    residual_y[..., :2] = yb[..., :2] - torch.cat([y_start, yb[:, :-1, :2]], dim=1)
    
                    pred_residuals = model(xb, target_seq=residual_y, teacher_forcing_ratio=0.0)
    
                    data_loss = weighted_coord_loss(pred_residuals, residual_y, coord_weight=coord_weight)
    
                    pred_xy = torch.cumsum(pred_residuals[..., :2], dim=1) + y_start
                    full_seq = torch.cat([xb[..., :2], pred_xy], dim=1)
                    loss_smooth = xy_smoothness_loss(full_seq)
    
                    val_loss += (data_loss + smooth_weight * loss_smooth).item()
    
            val_loss /= len(val_dl)
            print(f"           Val Loss: {val_loss:.6f}")
    
            if val_loss < best_loss:
                best_loss = val_loss
                best_state = model.state_dict()
    
        if best_state is not None:
            torch.save(best_state, "best_model_seq2seq_residual_xy_08302.pth")
            print("✅ Best model saved")
    
    # Setting seeds is extremely important for research purposes to make sure the results are reproducible
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_X = data['train_X']
    input_steps = 80
    output_steps = 2
    # Collapse batch and time dims to compute global min for each feature
    flat_train_X = train_X.view(-1, train_X.shape[-1])  # shape: [N*T, num_features]
    
    
    input_size = 5
    hidden_size = 64
    
    model = Seq2SeqLSTM(input_size, hidden_size, input_steps, output_steps).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.HuberLoss()
    
    train_model(model, train_dl,val_dl, optimizer, device)
    
    def evaluate_with_errors(
        model,
        test_dl,
        proj,
        feature_scaler,
        device,
        num_batches=1,
        outputs_are_residual_xy: bool = True,
        dup_tol: float = 1e-4,
        residual_decode_mode: str = "cumsum",   # "cumsum", "independent", "stdonly"
        residual_std: np.ndarray = None,        # needed if residual_decode_mode="stdonly"
    ):
        """
        Evaluates model and prints/plots errors.
    
        Parameters
        ----------
        outputs_are_residual_xy : bool
            If True, model outputs residuals (scaled). Else, absolute coords (scaled).
        residual_decode_mode : str
            - "cumsum": cumulative sum of residuals (default).
            - "independent": treat each residual as absolute offset from last obs.
            - "stdonly": multiply by provided residual_std instead of scaler.scale_.
        residual_std : np.ndarray
            Std dev for decoding (meters), used only if residual_decode_mode="stdonly".
        """
        model.eval()
        errors_all = []
    
        def inverse_xy_only(xy_scaled, scaler):
            """Inverse transform x,y only (works with StandardScaler/RobustScaler)."""
            if hasattr(scaler, "mean_"):
                center = scaler.mean_[:2]
                scale = scaler.scale_[:2]
            elif hasattr(scaler, "center_"):
                center = scaler.center_[:2]
                scale = scaler.scale_[:2]
            else:
                raise ValueError("Scaler type not supported for XY inversion")
            return xy_scaled * scale + center
    
        def haversine(lon1, lat1, lon2, lat2):
            """Distance (m) between two lon/lat points using haversine formula."""
            R = 6371000.0
            lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])
            dlon = lon2 - lon1
            dlat = lat2 - lat1
            a = np.sin(dlat/2.0)**2 + np.cos(lat1)*np.cos(lat2)*np.sin(dlon/2.0)**2
            return 2 * R * np.arcsin(np.sqrt(a))
    
        with torch.no_grad():
            batches = 0
            for xb, yb in test_dl:
                xb = xb.to(device); yb = yb.to(device)
                pred = model(xb, teacher_forcing_ratio=0.0)  # [B, T_out, F]
    
                # first sample for visualization/diagnostics
                input_seq = xb[0].cpu().numpy()
                real_seq  = yb[0].cpu().numpy()
                pred_seq  = pred[0].cpu().numpy()
    
                input_xy_s = input_seq[:, :2]
                real_xy_s  = real_seq[:, :2]
                pred_xy_s  = pred_seq[:, :2]
    
                # reconstruct predicted absolute trajectory
                if outputs_are_residual_xy:
                    last_obs_xy_s = input_xy_s[-1]
                    last_obs_xy_m = inverse_xy_only(last_obs_xy_s, feature_scaler)
    
                    # choose decoding mode
                    if residual_decode_mode == "stdonly":
                        if residual_std is None:
                            raise ValueError("residual_std must be provided for 'stdonly' mode")
                        pred_resid_m = pred_xy_s * residual_std
                        pred_xy_m = np.cumsum(pred_resid_m, axis=0) + last_obs_xy_m
    
                    elif residual_decode_mode == "independent":
                        if hasattr(feature_scaler, "scale_"):
                            scale = feature_scaler.scale_[:2]
                        else:
                            raise ValueError("Scaler missing scale_ for residual decoding")
                        pred_resid_m = pred_xy_s * scale
                        pred_xy_m = pred_resid_m + last_obs_xy_m  # no cumsum
    
                    else:  # default: "cumsum"
                        if hasattr(feature_scaler, "scale_"):
                            scale = feature_scaler.scale_[:2]
                        else:
                            raise ValueError("Scaler missing scale_ for residual decoding")
                        pred_resid_m = pred_xy_s * scale
                        pred_xy_m = np.cumsum(pred_resid_m, axis=0) + last_obs_xy_m
    
                else:
                    pred_xy_m = inverse_xy_only(pred_xy_s, feature_scaler)
    
                # Handle duplicate first target
                trimmed_real_xy_s = real_xy_s.copy()
                dropped_duplicate = False
                if trimmed_real_xy_s.shape[0] >= 1:
                    if np.allclose(trimmed_real_xy_s[0], input_xy_s[-1], atol=dup_tol):
                        trimmed_real_xy_s = trimmed_real_xy_s[1:]
                        dropped_duplicate = True
    
                # Align lengths
                len_pred = pred_xy_m.shape[0]
                len_real = trimmed_real_xy_s.shape[0]
                min_len = min(len_pred, len_real)
                if min_len == 0:
                    print("No overlapping horizon — skipping batch.")
                    batches += 1
                    if batches >= num_batches:
                        break
                    else:
                        continue
    
                pred_xy_m = pred_xy_m[:min_len]
                real_xy_m = inverse_xy_only(trimmed_real_xy_s[:min_len], feature_scaler)
                input_xy_m = inverse_xy_only(input_xy_s, feature_scaler)
    
                # debug
                print("\nDEBUG (scaled -> unscaled):")
                print("last_obs_scaled:", input_xy_s[-1])
                print("real_first_scaled:", real_xy_s[0])
                print("pred_first_scaled_delta:", pred_xy_s[0])
                print("dropped_duplicate_target:", dropped_duplicate)
                print("pred length:", pred_xy_m.shape[0], "real length:", real_xy_m.shape[0])
                print("last_obs_unscaled:", input_xy_m[-1])
                print("pred_first_unscaled:", pred_xy_m[0])
    
                # lon/lat conversion
                lon_in,  lat_in   = proj(input_xy_m[:,0],  input_xy_m[:,1],  inverse=True)
                lon_real, lat_real= proj(real_xy_m[:,0],   real_xy_m[:,1],   inverse=True)
                lon_pred, lat_pred= proj(pred_xy_m[:,0],   pred_xy_m[:,1],   inverse=True)
    
                # comparison table
                print("\n=== Predicted vs True (lat/lon) ===")
                print(f"{'t':>3} | {'lon_true':>9} | {'lat_true':>9} | {'lon_pred':>9} | {'lat_pred':>9} | {'err_m':>9}")
                errors = []
                for t in range(len(lon_real)):
                    err_m = haversine(lon_real[t], lat_real[t], lon_pred[t], lat_pred[t])
                    errors.append(err_m)
                    print(f"{t:3d} | {lon_real[t]:9.5f} | {lat_real[t]:9.5f} | {lon_pred[t]:9.5f} | {lat_pred[t]:9.5f} | {err_m:9.2f}")
    
                errors_all.append(errors)
    
                # plot
                fig = plt.figure(figsize=(8, 6))
                ax = plt.axes(projection=ccrs.PlateCarree())
                
                # set extent dynamically around trajectory
                all_lons = np.concatenate([lon_in, lon_real, lon_pred])
                all_lats = np.concatenate([lat_in, lat_real, lat_pred])
                lon_min, lon_max = all_lons.min() - 0.01, all_lons.max() + 0.01
                lat_min, lat_max = all_lats.min() - 0.01, all_lats.max() + 0.01
                ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
                
                # add map features
                ax.add_feature(cfeature.COASTLINE)
                ax.add_feature(cfeature.LAND, facecolor="lightgray")
                ax.add_feature(cfeature.OCEAN, facecolor="lightblue")
                
                # plot trajectories
                ax.plot(lon_in,   lat_in,   "o-", label="history", transform=ccrs.PlateCarree(),
                        markersize=6, linewidth=2)
                ax.plot(lon_real, lat_real, "o-", label="true",    transform=ccrs.PlateCarree(),
                        markersize=6, linewidth=2)
                ax.plot(lon_pred, lat_pred, "x--", label="pred",   transform=ccrs.PlateCarree(),
                        markersize=8, linewidth=2)
                
                ax.legend()
                plt.show()
    
    
                batches += 1
                if batches >= num_batches:
                    break
    
        # summary
        if errors_all:
            errors_all = np.array(errors_all)
            mean_per_t = errors_all.mean(axis=0)
            print("\n=== Summary (meters) ===")
            for t, v in enumerate(mean_per_t):
                print(f"t={t} mean error: {v:.2f} m")
            print(f"mean over horizon: {errors_all.mean():.2f} m, median: {np.median(errors_all):.2f} m")
    
    # --- device ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = torch.load("datasets_seq2seq_cartesian5.pt")
    
    train_X, train_Y = data['train_X'], data['train_Y']
    val_X, val_Y     = data['val_X'], data['val_Y']
    test_X, test_Y   = data['test_X'], data['test_Y']
    
    
    # --- load scalers ---
    feature_scaler = joblib.load("feature_scaler.pkl")  # fitted on x,y,sog,cog_sin,cog_cos
    sog_scaler     = joblib.load("sog_scaler.pkl")      
    
    # --- load projection ---
    with open("proj_params.json", "r") as f:
        proj_params = json.load(f)
    proj = pyproj.Proj(**proj_params)
    
    # --- rebuild & load model ---
    input_size = train_X.shape[2]   # features per timestep
    input_steps = train_X.shape[1]  # seq length in
    output_steps = train_Y.shape[1] # seq length out
    
    hidden_size = 64    
    num_layers = 2      
    
    best_model = Seq2SeqLSTM(
        input_size=5,
        hidden_size=64,
        input_steps=80,   
        output_steps=2,   
    ).to(device)
    
    best_model.load_state_dict(torch.load(
        "best_model_seq2seq_residual_xy_08302.pth", map_location=device
    ))
    best_model.eval()
    
    
    
    import numpy as np
    
    # helper: inverse xy using your loaded feature_scaler (same as in evaluate)
    def inverse_xy_only_np(xy_scaled, scaler):
        if hasattr(scaler, "mean_"):
            center = scaler.mean_[:2]
            scale  = scaler.scale_[:2]
        elif hasattr(scaler, "center_"):
            center = scaler.center_[:2]
            scale  = scaler.scale_[:2]
        else:
            raise ValueError("Scaler type not supported for XY inversion")
        return xy_scaled * scale + center
    
    
    N = train_X.shape[0]
    all_resids = []
    for i in range(N):
        last_obs_s = train_X[i, -1, :2]               # scaled
        true_s     = train_Y[i, :, :2]                # scaled (T_out,2)
    
        last_obs_m = inverse_xy_only_np(last_obs_s, feature_scaler)
        true_m     = inverse_xy_only_np(true_s, feature_scaler)
    
       
        if true_m.shape[0] == 0:
            continue
        resid0 = true_m[0] - last_obs_m
        if true_m.shape[0] > 1:
            rest = np.diff(true_m, axis=0)
            resids = np.vstack([resid0[None, :], rest])   # shape [T_out, 2]
        else:
            resids = resid0[None, :]
    
        all_resids.append(resids)
    
    # stack to compute std per axis across all time steps and samples
    all_resids_flat = np.vstack(all_resids)  # [sum_T, 2]
    residual_std = np.std(all_resids_flat, axis=0)  # [std_dx_m, std_dy_m]
    print("Computed residual_std (meters):", residual_std)
    
    evaluate_with_errors(
        best_model, test_dl, proj, feature_scaler, device,
        num_batches=1,
        outputs_are_residual_xy=True,
        residual_decode_mode="stdonly",
        residual_std=residual_std
    )
    
    

    hashtag
    Process

    The pipeline runs the following steps:

    1. Split by time: Breaks continuous vessel tracks into separate segments when there's a significant time gap in the data

    2. Encode track: Applies distance and speed-based encoding to smooth and normalize vessel trajectories

    3. Discretize H3: Converts geographic coordinates into H3 hexagonal grid cells for spatial analysis

    4. Detect stops: Identifies vessel stopping periods based on speed thresholds and duration

    5. Add bathymetry: Incorporates water depth data from GeoTIFF files along vessel routes

    6. Add weather: Integrates weather data (wind, pressure, etc.) with vessel tracking information

    The pipeline exposes the following parameters:

    Gap minutes (default: 4320) - Maximum allowed time gap in minutes between consecutive AIS points before splitting into new track segments

    Time split (days) (default: 15) - Alternative way to split tracks by specifying the maximum duration in days

    Encode distance (m) (default: 200000) - Maximum distance threshold in meters for track encoding/smoothing

    Encode speed (knots) (default: 50) - Maximum speed threshold in knots for track encoding/filtering

    H3 resolution (default: 6) - Resolution level for H3 hexagonal grid cells (0-15, higher means smaller cells)

    Bathymetry raster - Path to GeoTIFF file containing water depth data

    Weather shortnames - List of weather parameters to include (e.g., '10u,10v,msl' for wind components and sea level pressure)

    Distance split (m) (default: 30000) - Distance threshold in meters for breaking tracks into segments

    Speed split (knots) (default: 30) - Speed threshold in knots for track segmentation
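
    As a rough illustration, the parameters above can be collected into a single configuration dictionary before running the pipeline. This is only a sketch: the dictionary keys and the idea of passing them as one object are our own convention, not part of the AISdb API.

    # Hypothetical configuration mirroring the defaults listed above
    pipeline_config = {
        "gap_minutes": 4320,            # split tracks on time gaps larger than 3 days
        "time_split_days": 15,          # alternative split by maximum duration
        "encode_distance_m": 200000,    # distance threshold for track encoding/smoothing
        "encode_speed_knots": 50,       # speed threshold for track encoding/filtering
        "h3_resolution": 6,             # H3 hexagonal grid resolution
        "bathymetry_raster": "gebco_depth.tif",       # hypothetical GeoTIFF path
        "weather_shortnames": ["10u", "10v", "msl"],  # wind components and sea-level pressure
        "distance_split_m": 30000,      # distance threshold for segmentation
        "speed_split_knots": 30,        # speed threshold for segmentation
    }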

    hashtag
    Static Plot

    The Static Plot feature leverages Matplotlib and Cartopy to create high-quality, publication-ready visualizations of AIS vessel tracks. It offers options for coastline overlays, multiple export formats (PNG/PDF), and precise geographic projections, making it ideal for scientific publications and detailed analysis of maritime traffic patterns.
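
    As a minimal sketch of what such a static plot can look like, the snippet below draws a set of AISdb-style track dictionaries (with 'lon' and 'lat' arrays) over a Cartopy coastline and exports the figure; it is an illustration, not the exact implementation behind the feature.

    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature

    def plot_tracks_static(tracks, out_path="tracks.png"):
        """Render vessel tracks on a coastline map and export to PNG or PDF."""
        fig = plt.figure(figsize=(10, 8))
        ax = plt.axes(projection=ccrs.PlateCarree())
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.LAND, facecolor="lightgray")
        for track in tracks:  # each track holds 'lon' and 'lat' arrays
            ax.plot(track["lon"], track["lat"], linewidth=0.8, transform=ccrs.PlateCarree())
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
        plt.close(fig)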

    hashtag
    Plotly OSM

    Plotly OSM (OpenStreetMap) provides an interactive visualization experience, combining vessel tracking data with detailed map layers. This feature enables dynamic exploration of maritime routes with zoom capabilities, hover information, and the ability to toggle between line and marker modes. The interactive HTML export option makes it perfect for web-based presentations and sharing insights.
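
    A minimal sketch of an interactive map with Plotly over OpenStreetMap tiles is shown below; it assumes the tracks have already been flattened into a pandas DataFrame with 'lat', 'lon', and 'mmsi' columns, which is our own assumption rather than the feature's actual input format.

    import pandas as pd
    import plotly.express as px

    def plot_tracks_osm(df: pd.DataFrame, out_html="tracks.html"):
        """Interactive track map over OpenStreetMap tiles; df needs 'lat', 'lon', 'mmsi' columns."""
        fig = px.line_mapbox(df, lat="lat", lon="lon", color="mmsi", zoom=5)
        fig.update_layout(mapbox_style="open-street-map", margin=dict(l=0, r=0, t=0, b=0))
        fig.write_html(out_html)  # exportable, shareable HTML
        return fig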

    hashtag
    Explain

    The Explain feature harnesses LangChain and Google's Generative AI to provide natural language insights about your AIS data. Simply ask questions about your dataset, and receive detailed explanations about patterns, anomalies, and statistics, making complex maritime data analysis accessible to non-technical stakeholders.

    hashtag
    Gemini

    The Gemini integration takes maritime data visualization to the next level by combining Google's advanced AI model with visual context. It can analyze both your data and generated plots simultaneously, offering comprehensive insights, pattern recognition, and detailed explanations of maritime behavior that might not be immediately apparent to human observers.

    file-download
    49KB
    AIS-BOT 1.py
    Preprocessing

    The first step in building a trajectory learning pipeline is preprocessing AIS tracks into a model-friendly format. Raw AIS messages are noisy, irregularly sampled, and inconsistent across vessels, so we need to enforce structure before feeding them into neural networks. The function below does several things in sequence:

    1. Data cleaning – removes spurious pings based on unrealistic speeds, encodes great-circle distances, and interpolates trajectories at fixed 5-minute intervals.

    2. Track filtering – groups data by vessel (MMSI) and keeps only sufficiently long tracks to ensure stable training samples.

    3. Feature extraction – converts lat/lon into projected coordinates (x, y), adds speed over ground (sog), and represents course over ground (cog) as sine/cosine to avoid angular discontinuities.

    4. Delta computation – calculates dx and dy between consecutive timestamps, capturing local motion dynamics.

    5. Scaling – applies RobustScaler to normalize features and deltas while being resilient to outliers (common in AIS data).

    The result is a clean, scaled DataFrame where each row represents a vessel state at a timestamp, enriched with both absolute position features and relative motion features.

    We first query the AIS database for the training and testing periods and geographic bounds, producing generators of raw vessel tracks. These raw tracks are then preprocessed using preprocess_aisdb_tracks, which cleans the data, scales the features, and outputs a ready-to-use DataFrame. Training data fits new scalers, while test data is transformed using the same scalers to ensure consistency.

    The create_sequences function transforms the preprocessed track data into supervised sequences suitable for model training. For each vessel, it slides a fixed-size window over the time series, building input sequences of past features (x, y, cog_sin, cog_cos, sog_scaled) and target sequences of the same features for future steps, from which residual movements are derived during training. Using this, the dataset is split into training, validation, and test sets, with each set containing sequences ready for direct input into a trajectory prediction model.

    hashtag
    Save the Dataset

    This block saves all processed data and supporting objects needed for training and evaluation. The preprocessed input and target sequences for training, validation, and testing are serialized with PyTorch (datasets_seq2seq_cartesian5.pt). The fitted scalers for the features and for speed over ground are saved with joblib to ensure consistent scaling during inference. Finally, the projection parameters used to convert geographic coordinates to UTM are stored in JSON, allowing consistent coordinate transformations later.

    hashtag
    Model

    Seq2SeqLSTM is the sequence-to-sequence backbone used in the NPINN-based trajectory prediction framework. Within NPINN, the encoder LSTM processes past vessel observations (positions, speed, and course) to produce a hidden representation that captures motion dynamics. The decoder LSTMCell predicts future residuals in x and y, with an attention mechanism that selectively focuses on relevant past information at each step. Predicted residuals are added to the last observed position to reconstruct absolute trajectories.

    This setup enables NPINN to generate smooth, physically consistent multi-step vessel trajectories, leveraging both historical motion patterns and learned dynamics constraints.

    hashtag
    Model Training

    This training loop implements NPINN-based trajectory learning, combining data fidelity with physics-inspired smoothness constraints. The weighted_coord_loss enforces accurate prediction of future x, y positions, while xy_smoothness_loss encourages smooth velocity and acceleration profiles, reflecting realistic vessel motion.

    By integrating these two objectives, NPINN learns trajectories that are both close to observed data and physically plausible, with the smoothness weight gradually decaying during training to balance learning accuracy with dynamic consistency. Validation is performed each epoch to ensure generalization, and the best-performing model is saved. This approach differentiates NPINN from standard Seq2Seq training by explicitly incorporating motion dynamics into the loss, rather than relying purely on sequence prediction.

    hashtag
    Training

    This training sets a fixed random seed for reproducibility (torch, numpy, random) and enables deterministic CuDNN behavior. It initializes a Seq2SeqLSTM NPINN model to predict future vessel trajectory residuals from past sequences of absolute features. Global min/max of xy coordinates are computed from the training set for NPINN’s smoothness loss normalization. The model is trained on GPU if available using Adam, combining data loss on predicted xy positions with a physics-inspired smoothness penalty to enforce realistic, physically plausible trajectories.

    hashtag
    Evaluate

    This snippet sets up the environment for inference or evaluation of the NPINN Seq2Seq model:

    • Chooses GPU if available.

    • Loads preprocessed datasets (train_X/Y, val_X/Y, test_X/Y).

    • Loads the saved RobustScalers for the features and for SOG to match the preprocessing used during training.

    • Loads projection parameters to convert lon/lat to projected coordinates consistently.

    • Rebuilds the same Seq2SeqLSTM NPINN model used during training and loads the best saved weights.

    • Puts the model in evaluation mode, ready for predicting future vessel trajectories.

    Essentially, this is the full recovery pipeline for NPINN inference, ensuring consistency with training preprocessing, scaling, and projection.

    This code provides essential postprocessing and geometric utilities for working with NPINN outputs and AIS trajectory data. The inverse_xy_only_np function converts scaled x, y values back into real-world units (meters) using the previously fitted scaler, and works for both single positions and arrays of positions. This is particularly useful for interpreting NPINN predictions in absolute physical units rather than in the normalized or scaled feature space, allowing for meaningful evaluation of the model’s accuracy in real-world terms. Using this, the code also computes the standard deviation of the step-to-step residuals across the training dataset, providing a quantitative measure of typical displacement magnitudes along the x and y axes.

    The code also includes geometry-related helpers to analyze trajectories in geospatial terms. The haversine function calculates the geodesic distance between longitude/latitude points in meters using the haversine formula, with safeguards for numerical stability and invalid inputs. Building on this, the trajectory_length function computes the total length of a vessel’s trajectory, summing distances between consecutive points while handling incomplete or non-finite data gracefully. Together, these utilities allow NPINN outputs to be mapped back to real-world coordinates, facilitate evaluation of trajectory smoothness and accuracy, and provide interpretable metrics for model validation and downstream analysis.

    hashtag
    Evaluate Function

    This function evaluate_with_errors is designed to evaluate NPINN trajectory predictions in a geospatial context and optionally visualize them. It takes a trained model, a test DataLoader, coordinate projection, scalers, and device information. For each batch, it reconstructs the predicted trajectories from residuals (dx, dy), inverts the scaling back to meters, and converts them to absolute positions starting from the last observed point. Different decoding modes (cumsum, independent, stdonly) allow flexibility in how residuals are integrated into absolute trajectories, and it handles cases where the first residual is effectively a duplicate of the last input.

    The evaluation computes per-timestep errors in meters using the haversine formula and tracks differences in trajectory lengths. All errors are summarized with mean and median statistics across the prediction horizon. The function also plots each evaluated sample on a Cartopy map, overlaying the observed history, the true path, and the predicted path, giving a clear visual inspection of the model’s performance. This approach is directly aligned with NPINN, as it evaluates predictions in physical units and emphasizes smooth, physically plausible trajectory reconstructions.

    hashtag
    Results

    hashtag
    Trajectory Plots

    1. Trajectory 2

      • True length: 521.39 m

      • Predicted length: 508.66 m

      • Difference: 12.73 m

      The predicted path is roughly parallel to the true path but slightly offset. The predicted trajectory underestimates the total path length by ~2.4%, which is a small but noticeable error. The smoothness of the red dashed line indicates that the model is generating physically plausible, consistent trajectories.

    2. Trajectory 5

      • True length: 188.76 m

      • Predicted length: 206.01 m

    hashtag
    Summary Statistics

    • t=0 mean error: 45.03 m The first prediction step already has a ~45 m average discrepancy from the true position, which is common since the model starts accumulating error immediately after the last observed point.

    • t=1 mean error: 80.80 m Error grows with the horizon, reflecting cumulative effects of residual inaccuracies.

    • Mean over horizon: 62.92 m | Median: 61.72 m On average, predictions are within ~60–63 m of the true trajectory at any given time step. Median being close to mean suggests a fairly symmetric error distribution without extreme outliers.

    • Mean trajectory length difference: 11.82 m | Median: 12.73 m Overall, the predicted trajectories’ total lengths are very close to the true lengths, typically within ~12 m, which is less than 3% relative error for most trajectories.

    hashtag
    Interpretation

    • The model captures trajectory trends well but shows small offsets in absolute positions.

    • Errors grow with horizon, which is typical for sequence prediction models using residuals.

    • Smoothness is maintained (no erratic jumps), indicating that the NPINN smoothness regularization is effective.

    • Overall, this is a solid performance for maritime AIS trajectory prediction, especially given the scale of trajectories (hundreds of meters).

    • Indexing: a pipeline for scraping data from documentation and indexing it. This usually happens offline.

    • Retrieval and generation: the actual RAG chain, which takes the user query at runtime and retrieves the relevant data from the index, then passes it to the model.

    The most common complete sequence from raw docs to answer looks like:

    hashtag
    Indexing

    1. Scrape: First, we need to scrape all documentation pages. This includes the GitBook documentation and related pages.

    2. Split: Text splitters break large documents into smaller chunks. This is useful both for indexing data and passing it into a model, since large chunks are harder to search over and won't fit in a model's finite context window.

    3. Store: We need somewhere to store and index our splits, so that they can be searched over later. This is done using the Chroma vector database and embeddings.

    hashtag
    Retrieval and generation

    1. Retrieve: Given a user input, relevant splits are retrieved from Chroma using similarity search.

    2. Generate: An LLM produces an answer in response to a system prompt that combines both the question and the retrieved context.

    hashtag
    Setup

    hashtag
    Installation

    This tutorial requires these dependencies:
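
    The exact package list below is an assumption based on the components used in this tutorial (LangChain with the Google Generative AI integration, SentenceTransformers, Chroma, and Gradio); adjust it to your environment as needed.

    pip install langchain langchain-community langchain-google-genai langchain-text-splitters sentence-transformers chromadb gradio requests beautifulsoup4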

    hashtag
    API Keys

    You'll need a Google LLM API key (or a key from another LLM provider). Set it as an environment variable:
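
    A minimal way to do this from Python is sketched below; GOOGLE_API_KEY is the variable read by the LangChain Google Generative AI integration, and the interactive prompt fallback is just a convenience.

    import os
    import getpass

    # Prompt for the key only if it is not already set in the environment
    if not os.environ.get("GOOGLE_API_KEY"):
        os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter your Google API key: ")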

    hashtag
    Components

    We need to select three main components:

    • LLM: We'll use Google's Gemini models through LangChain

    • Embeddings: Hugging Face SentenceTransformers for creating document embeddings

    • Vector Store: Chroma for storing and searching document embeddings
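
    A minimal initialization sketch for these three components is shown below. The embedding model name and the Chroma collection and persistence settings are illustrative choices, not requirements.

    from langchain_google_genai import ChatGoogleGenerativeAI
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import Chroma

    # LLM: Google Gemini through LangChain (model name taken from the walkthrough below)
    llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0)

    # Embeddings: a SentenceTransformers model via the Hugging Face wrapper
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # Vector store: Chroma, persisted locally (collection name is an arbitrary choice)
    vector_store = Chroma(collection_name="aisviz_docs",
                          embedding_function=embeddings,
                          persist_directory="./chroma_db")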

    hashtag
    Preview

    We can create a simple indexing pipeline and RAG chain to do this in about 100 lines of code.

    hashtag
    Detailed walkthrough

    hashtag
    1. Indexing

    Scraping Documentation

    We need to first scrape all the AISViz documentation pages.
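
    A simple scraping sketch with requests and BeautifulSoup is shown below; the URL list is a placeholder, and the real pipeline may crawl links rather than use a fixed list.

    import requests
    from bs4 import BeautifulSoup

    def scrape_page(url: str) -> dict:
        """Fetch one documentation page and return its visible text plus the source URL."""
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        text = soup.get_text(separator="\n", strip=True)
        return {"url": url, "text": text}

    # Placeholder list of documentation URLs to crawl
    doc_urls = ["https://aisviz.gitbook.io/documentation"]
    documents = [scrape_page(u) for u in doc_urls]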

    Splitting documents

    Our scraped documents can be quite long, so we need to split them into smaller chunks. We'll use a simple text splitter that breaks documents into chunks of specified size with some overlap.
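
    A sketch using LangChain's RecursiveCharacterTextSplitter is shown below; the chunk size and overlap are illustrative values, and the loop builds on the documents list from the scraping step.

    from langchain_text_splitters import RecursiveCharacterTextSplitter

    # Chunk size and overlap are illustrative values, not tuned settings
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)

    chunks = []
    for doc in documents:  # documents come from the scraping step above
        for piece in splitter.split_text(doc["text"]):
            chunks.append({"text": piece, "url": doc["url"]})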

    Storing documents with SentenceTransformers

    Now we need to create embeddings for our chunks using Hugging Face SentenceTransformers and store them in the Chroma vector database.
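
    A minimal sketch of this step is shown below, reusing the vector_store created in the Components section and the chunks produced by the splitter; the metadata layout is an assumption.

    # Add the chunk texts and their source URLs to the Chroma collection.
    # `vector_store` is the Chroma instance created in the Components section.
    texts = [c["text"] for c in chunks]
    metadatas = [{"source": c["url"]} for c in chunks]
    vector_store.add_texts(texts=texts, metadatas=metadatas)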

    This completes the Indexing portion of the pipeline. At this point, we have a queryable vector store containing the chunked contents of all the documentation with embeddings created by SentenceTransformers. Given a user question, we should be able to return the most relevant snippets.

    hashtag
    2. Retrieval and Generation

    Now let's write the actual application logic. We aim to create a simple function that takes a user question, searches for relevant documents using SentenceTransformers embeddings, and generates an answer using Google Gemini. A high-level breakdown:

    • API Key Setup – Load LLM (Gemini in this example) API key from environment or prompt user.

    • Model Initialization – Wrap Google Gemini (gemini-2.5-flash) with LangChain.

    • Embedding – Convert the user's question into a vector with SentenceTransformers.

    • Retrieval – Query Chroma to fetch the top-k most relevant document chunks.

    • Context Building – Assemble retrieved docs + metadata into a context string.

    • Prompting LLM – Combine system + user prompts and send to LLM.

    • Answer Generation – Return a concise response along with sources and context.
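
    A compact sketch of this flow is shown below, reusing the llm and vector_store objects defined earlier; the prompt wording and the value of k are illustrative.

    def answer_question(question: str, k: int = 4) -> dict:
        """Retrieve the top-k most relevant chunks from Chroma and ask Gemini to answer."""
        docs = vector_store.similarity_search(question, k=k)
        context = "\n\n".join(f"[{d.metadata.get('source', 'unknown')}]\n{d.page_content}" for d in docs)

        messages = [
            ("system", "You answer questions about the AISViz/AISdb documentation. "
                       "Use only the provided context and cite the sources you used."),
            ("human", f"Context:\n{context}\n\nQuestion: {question}"),
        ]
        response = llm.invoke(messages)

        return {"answer": response.content,
                "sources": [d.metadata.get("source") for d in docs]}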

    hashtag
    3. Building a Gradio Interface

    Now let's create a simple web interface using Gradio so others can interact with our chatbot:
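
    A minimal sketch with gr.ChatInterface is shown below, wrapping the answer_question function from the previous step; the published Space linked below differs in its details.

    import gradio as gr

    def chat_fn(message, history):
        result = answer_question(message)
        sources = "\n".join(s for s in result["sources"] if s)
        return f"{result['answer']}\n\nSources:\n{sources}"

    demo = gr.ChatInterface(fn=chat_fn, title="AISViz Documentation Chatbot")
    demo.launch()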

    The full code can be found here: https://huggingface.co/spaces/mapslab/AISVIZ-BOT/tree/mainarrow-up-right

    hashtag
    Below is a working chatbot for testing!

    hashtag
    Processing AIS Tracks into Clean Segments

    This function pulls raw AIS data from a database, denoises it, splits tracks into time-consistent segments, filters outliers, and interpolates them at fixed time steps. The result is a set of clean, continuous vessel trajectories ready for embedding.
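
    A sketch of this processing chain using AISdb calls that appear elsewhere in this documentation is shown below; the split_timedelta call and all thresholds are assumptions that should be adapted to your data.

    from datetime import timedelta
    import aisdb
    from aisdb import DBConn
    from aisdb.database.sqlfcn_callbacks import in_timerange
    from aisdb.track_gen import TrackGen

    def clean_segments(dbpath, start, end, bbox):
        """Query raw AIS rows and turn them into denoised, regularly sampled segments."""
        xmin, ymin, xmax, ymax = bbox
        dbconn = DBConn(dbpath=dbpath)
        qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange, start=start, end=end,
                            xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)
        tracks = TrackGen(qry.gen_qry(verbose=True), decimate=False)

        # Split tracks on large time gaps (split_timedelta is assumed to exist in aisdb)
        tracks = aisdb.split_timedelta(tracks, timedelta(hours=24))
        # Filter implausible jumps and smooth the segments (thresholds are illustrative)
        tracks = aisdb.encode_greatcircledistance(tracks, distance_threshold=50000,
                                                  minscore=1e-5, speed_threshold=50)
        # Resample every segment at a fixed time step
        tracks = aisdb.interp_time(tracks, step=timedelta(minutes=10))
        return tracks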

    hashtag
    Loading Region and Grid Shapefiles

    We load the study region (Gulf shapefile) and a hexagonal grid (H3 resolution 6). These will be used to map vessel positions into discrete spatial cells — the “tokens” for our trajectory embedding model.

    Each trajectory is converted from lat/lon coordinates into H3 hexagon IDs at resolution 6. To avoid redundant entries, we deduplicate consecutive identical cells while keeping the timestamp of first entry. The result is a sequence of discrete spatial tokens with time information — the input format for traj2vec.
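
    As a minimal illustration of this tokenization step (the coordinates below are arbitrary sample points; the h3 calls follow the v4 API used elsewhere on this page):

    import h3

    # Toy track: the first two fixes are ~20 m apart (very likely the same resolution-6
    # hexagon), while the third is tens of kilometres away
    lats  = [44.6488, 44.6490, 45.0000]
    lons  = [-63.5752, -63.5750, -64.0000]
    times = [0, 300, 600]   # seconds since the start of the track

    cells = [h3.latlng_to_cell(lat, lon, 6) for lat, lon in zip(lats, lons)]

    # Deduplicate consecutive identical cells, keeping the timestamp of first entry
    dedup_cells, dedup_times = [cells[0]], [times[0]]
    for cell, t in zip(cells[1:], times[1:]):
        if cell != dedup_cells[-1]:
            dedup_cells.append(cell)
            dedup_times.append(t)

    print(dedup_cells, dedup_times)   # discrete spatial tokens plus their entry times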

    Before training embeddings, it’s useful to check how long our AIS trajectories are. The function below computes summary statistics (min, max, mean, percentiles) and plots the distribution of track lengths in terms of H3 cells.

    In our dataset, the distribution is skewed to the right: most vessel tracks are relatively short, and only a few trajectories are very long.

    For a simpler visualization, we also plot trajectories in raw lat/lon space without cartographic features. This is handy for debugging and checking if preprocessing (deduplication, interpolation) worked correctly.

    Filtering out very short and very long tracks

    We collect all unique H3 IDs from the trajectories and assign each one an integer index. Just like in NLP, we also reserve special tokens for padding, start, and end of sequence. This turns spatial cells into a vocabulary that our embedding model can work with.

    Each vessel track is then mapped from its H3 sequence into an integer sequence (int_seq). We also convert the H3 cells back into lat/lon pairs for later visualization. At this point, the data is ready to be fed into a traj2vec-style model.

    hashtag
    Train Test Split + Data Saving

    We split the cleaned trajectories into train, validation, and test sets. This ensures our model can be trained, tuned, and evaluated fairly without data leakage.

    Each trajectory is written out in multiple aligned formats:

    • .src → input sequence (all tokens except last)

    • .trg → target sequence (all tokens except first)

    • .lat / .lon → raw geographic coordinates (for visualization)

    • .t → the complete trajectory sequence

    This setup mirrors NLP datasets, where models learn to predict the “next token” in a sequence.
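
    As a toy illustration of the format above (the token IDs are made up), a single integer-encoded trajectory is written out as:

    # Hypothetical integer-encoded trajectory
    int_seq = [7, 15, 42, 9]

    src = int_seq[:-1]   # [7, 15, 42]    -> one line in the .src file (model input)
    trg = int_seq[1:]    # [15, 42, 9]    -> one line in the .trg file (next-token targets)
    t   = int_seq        # [7, 15, 42, 9] -> one line in the .t file (full trajectory)

    # Each file holds one whitespace-separated line per trajectory
    print(" ".join(map(str, src)))
    print(" ".join(map(str, trg)))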

    hashtag
    Some Other Imports for Training

    hashtag
    Training Loop Setup

    We set up utility functions to initialize model weights, save checkpoints during training, and run validation. These ensure training is reproducible and models can be restored later. The train() function loads train/val datasets, defines the loss functions (negative log-likelihood or KL-divergence), and builds the encoder-decoder model with its optimizer and scheduler. If a checkpoint exists, training resumes from where it left off; otherwise, parameters are freshly initialized.

    hashtag
    Generative + Discriminative Losses

    Training uses two objectives:

    • Generative loss (predicting the next trajectory cell, like word prediction in NLP).

    • Discriminative loss (triplet margin loss, ensuring embeddings of similar trajectories are close while different ones are far apart).

    These combined losses help the model learn not only to generate realistic trajectories but also to embed them in a useful vector space.
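
    A minimal sketch of how these two objectives can be combined into a single training loss (the `disc_weight` factor and tensor shapes are illustrative assumptions; the margin and p values mirror the triplet loss used in the training code below):

    import torch.nn as nn

    nll = nn.NLLLoss(ignore_index=0)                  # generative: next-cell prediction (PAD = 0)
    triplet = nn.TripletMarginLoss(margin=1.0, p=2)   # discriminative: embedding structure

    def combined_loss(log_probs, targets, anchor, positive, negative, disc_weight=0.1):
        # log_probs: [N, vocab] log-softmax outputs; targets: [N] true next-cell token IDs
        gen = nll(log_probs, targets)
        # anchor/positive/negative: [B, H] embeddings of similar vs. dissimilar trajectories
        disc = triplet(anchor, positive, negative)
        return gen + disc_weight * disc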

    The loop runs over iterations, logging training progress, validating periodically, and saving checkpoints. A learning rate scheduler adjusts the optimizer based on validation loss, and early stopping prevents wasted computation when no improvements occur.

    The test() function loads the best checkpoint, evaluates it on the test set, and reports average loss and perplexity. Perplexity is borrowed from NLP — lower values mean the model is more confident in predicting the next trajectory cell.

    ARGS

    Test

    Results

    The generative model performs strongly on the test dataset. Aggregating the generative loss over all test batches and normalizing per token gives an average loss of 0.2309, which corresponds to a perplexity of approximately 1.26.

    Perplexity is a standard measure in sequence modeling that quantifies how well a probabilistic model predicts a sequence of tokens. A perplexity close to 1 indicates near-deterministic prediction, meaning that the model assigns very high probability to the correct next token in the sequence. In the context of vessel trajectories, this result implies that the model is extremely confident and precise in forecasting the next H3 cell in a track, capturing the underlying spatial and temporal patterns in the data.
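
    Concretely, perplexity is the exponential of the average per-token loss, so the number reported above can be reproduced in one line:

    import math

    avg_loss_per_token = 0.2309                    # average generative loss reported above
    print(f"{math.exp(avg_loss_per_token):.2f}")   # 1.26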

    These results are particularly noteworthy because vessel movements are constrained by both geography and navigational behavior. The model effectively learns these patterns, predicting transitions between cells with minimal uncertainty. Achieving such a low perplexity confirms that the preprocessing pipeline, H3 cell encoding, and the sequence modeling architecture are all functioning harmoniously, enabling highly accurate trajectory modeling.

    Overall, the evaluation demonstrates that the model not only generalizes well to unseen tracks but also reliably captures the deterministic structure of vessel movement, providing a robust foundation for downstream tasks such as trajectory prediction, anomaly detection, or maritime route analysis.

    import aisdb
    import numpy as np
    import pandas as pd

    from collections import defaultdict
    from datetime import timedelta
    from sklearn.preprocessing import RobustScaler
    
    def preprocess_aisdb_tracks(tracks_gen, proj,
                                sog_scaler=None,
                                feature_scaler=None,
                                delta_scaler=None,
                                fit_scaler=False):
        # --- AISdb cleaning ---
        tracks_gen = aisdb.remove_pings_wrt_speed(tracks_gen, 0.1)
        tracks_gen = aisdb.encode_greatcircledistance(
            tracks_gen,
            distance_threshold=50000,
            minscore=1e-5,
            speed_threshold=50
        )
        tracks_gen = aisdb.interp_time(tracks_gen, step=timedelta(minutes=5))
    
        # --- collect tracks ---
        tracks = list(tracks_gen)
    
        # --- group by MMSI ---
        tracks_by_mmsi = defaultdict(list)
        for track in tracks:
            tracks_by_mmsi[track['mmsi']].append(track)
    
        # --- keep only long-enough tracks ---
        valid_tracks = []
        for mmsi, mmsi_tracks in tracks_by_mmsi.items():
            if all(len(t['time']) >= 100 for t in mmsi_tracks):
                valid_tracks.extend(mmsi_tracks)
    
        # --- flatten into dataframe ---
        rows = []
        for track in valid_tracks:
            mmsi = track['mmsi']
            sog = track.get('sog', [np.nan]*len(track['time']))
            cog = track.get('cog', [np.nan]*len(track['time']))
    
            for i in range(len(track['time'])):
                x, y = proj(track['lon'][i], track['lat'][i])
                cog_rad = np.radians(cog[i]) if cog[i] is not None else np.nan
                rows.append({
                    'mmsi': mmsi,
                    'x': x,
                    'y': y,
                    'sog': sog[i],
                    'cog_sin': np.sin(cog_rad) if not np.isnan(cog_rad) else np.nan,
                    'cog_cos': np.cos(cog_rad) if not np.isnan(cog_rad) else np.nan,
                    'timestamp': pd.to_datetime(track['time'][i], errors='coerce')
                })
    
        df = pd.DataFrame(rows)
    
        # --- clean NaNs ---
        df = df.replace([np.inf, -np.inf], np.nan)
        df = df.dropna(subset=['x', 'y', 'sog', 'cog_sin', 'cog_cos'])
    
        # --- compute deltas per MMSI ---
        df = df.sort_values(["mmsi", "timestamp"])
        df["dx"] = df.groupby("mmsi")["x"].diff().fillna(0)
        df["dy"] = df.groupby("mmsi")["y"].diff().fillna(0)
    
        # --- scale features ---
        feature_cols = ['x', 'y', 'sog', 'cog_sin', 'cog_cos']
        delta_cols   = ['dx', 'dy']
    
        if fit_scaler:
            # sog
            sog_scaler = RobustScaler()
            df['sog_scaled'] = sog_scaler.fit_transform(df[['sog']])
    
            # absolute features
            feature_scaler = RobustScaler()
            df[feature_cols] = feature_scaler.fit_transform(df[feature_cols])
    
            # deltas
            delta_scaler = RobustScaler()
            df[delta_cols] = delta_scaler.fit_transform(df[delta_cols])
        else:
            df['sog_scaled'] = sog_scaler.transform(df[['sog']])
            df[feature_cols] = feature_scaler.transform(df[feature_cols])
            df[delta_cols]   = delta_scaler.transform(df[delta_cols])
    
        return df, sog_scaler, feature_scaler, delta_scaler
    train_qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
                              start=START_DATE, end=END_DATE,
                              xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX)
    train_gen = TrackGen(train_qry.gen_qry(verbose=True), decimate=False)
    
    test_qry = aisdb.DBQuery(dbconn=dbconn, callback=in_timerange,
                             start=TEST_START_DATE, end=TEST_END_DATE,
                             xmin=XMIN, xmax=XMAX, ymin=YMIN, ymax=YMAX)
    test_gen = TrackGen(test_qry.gen_qry(verbose=True), decimate=False)
    
    # --- Preprocess ---
    train_df, sog_scaler, feature_scaler, delta_scaler = preprocess_aisdb_tracks(
        train_gen, proj, fit_scaler=True
    )
    
    test_df, _, _, _ = preprocess_aisdb_tracks(
        test_gen,
        proj,
        sog_scaler=sog_scaler,
        feature_scaler=feature_scaler,
        delta_scaler=delta_scaler,
        fit_scaler=False
    )
    def create_sequences(df, features, input_size=80, output_size=2, step=1):
        """
        Build sequences:
          X: past window of absolute features (x, y, dx, dy, cog_sin, cog_cos, sog_scaled)
          Y: future residuals (dx, dy)
        """
        X_list, Y_list = [], []
        for mmsi in df['mmsi'].unique():
            sub = df[df['mmsi'] == mmsi].sort_values('timestamp').copy()
    
            # build numpy arrays
            feat_arr = sub[features].to_numpy()
            dxdy_arr = sub[['dx', 'dy']].to_numpy()  # residuals already scaled
    
            for i in range(0, len(sub) - input_size - output_size + 1, step):
                # input sequence is absolute features
                X_list.append(feat_arr[i : i + input_size])
    
                # output sequence is residuals immediately after
                Y_list.append(dxdy_arr[i + input_size : i + input_size + output_size])
    
        return torch.tensor(X_list, dtype=torch.float32), torch.tensor(Y_list, dtype=torch.float32)
    
    
    features = ['x', 'y', 'dx', 'dy', 'cog_sin', 'cog_cos', 'sog_scaled']
    
    mmsis = train_df['mmsi'].unique()
    train_mmsi, val_mmsi = train_test_split(mmsis, test_size=0.2, random_state=42, shuffle=True)
    
    train_X, train_Y = create_sequences(train_df[train_df['mmsi'].isin(train_mmsi)], features)
    val_X, val_Y     = create_sequences(train_df[train_df['mmsi'].isin(val_mmsi)], features)
    test_X, test_Y   = create_sequences(test_df, features)
    import torch
    import joblib
    import json
    
    # --- save datasets ---
    torch.save({
        'train_X': train_X, 'train_Y': train_Y,
        'val_X': val_X, 'val_Y': val_Y,
        'test_X': test_X, 'test_Y': test_Y
    }, 'datasets_npin.pt')
    
    # --- save scalers ---
    joblib.dump(feature_scaler, "npinn_feature_scaler.pkl")
    joblib.dump(sog_scaler, "npinn_sog_scaler.pkl")
    joblib.dump(delta_scaler, "npinn_delta_scaler.pkl")   # scaler for the dx/dy residuals
    
    
    # --- save projection parameters ---
    proj_params = {'proj': 'utm', 'zone': 20, 'ellps': 'WGS84'}
    with open("npinn_proj_params.json", "w") as f:
        json.dump(proj_params, f)
    
    
    data = torch.load('datasets_npin.pt')
    import torch
    import joblib
    import json
    # scalers
    feature_scaler = joblib.load("npinn_feature_scaler.pkl")
    sog_scaler = joblib.load("npinn_sog_scaler.pkl")
    delta_scaler   = joblib.load("npinn_delta_scaler.pkl")   # scaler for the dx/dy residuals
    
    # projection
    with open("npinn_proj_params.json", "r") as f:
        proj_params = json.load(f)
    proj = pyproj.Proj(**proj_params)
    
    train_ds = TensorDataset(data['train_X'], data['train_Y'])
    val_ds = TensorDataset(data['val_X'], data['val_Y'])
    test_ds = TensorDataset(data['test_X'], data['test_Y'])
    
    batch_size = 64
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_dl = DataLoader(test_ds, batch_size=batch_size)
    class Seq2SeqLSTM(nn.Module):
        def __init__(self, input_size, hidden_size, input_steps, output_steps):
            super().__init__()
            self.input_steps = input_steps
            self.output_steps = output_steps
    
            # Encoder
            self.encoder = nn.LSTM(input_size, hidden_size, num_layers=2, dropout=0.3, batch_first=True)
    
            # Decoder
            self.decoder = nn.LSTMCell(input_size, hidden_size)
            self.attn = nn.Linear(hidden_size * 2, input_steps)
            self.attn_combine = nn.Linear(hidden_size + input_size, input_size)
    
            # Output only x,y residuals (added to last observed pos)
            self.output_layer = nn.Sequential(
                nn.Linear(hidden_size, hidden_size // 2),
                nn.ReLU(),
                nn.Linear(hidden_size // 2, 2)
            )
    
        def forward(self, x, target_seq=None, teacher_forcing_ratio=0.5):
            batch_size = x.size(0)
            encoder_outputs, (h, c) = self.encoder(x)
            h, c = h[-1], c[-1]
    
            last_obs = x[:, -1, :2]   # last observed absolute x,y
            decoder_input = x[:, -1, :]  # full feature vector
    
            outputs = []
            for t in range(self.output_steps):
                attn_weights = torch.softmax(self.attn(torch.cat((h, c), dim=1)), dim=1)
                context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
    
                dec_in = torch.cat((decoder_input, context), dim=1)
                dec_in = self.attn_combine(dec_in)
    
                h, c = self.decoder(dec_in, (h, c))
                residual_xy = self.output_layer(h)
    
                # accumulate into absolute xy
                out_xy = residual_xy + last_obs
                outputs.append(out_xy.unsqueeze(1))
    
                # teacher forcing
                if self.training and target_seq is not None and t < target_seq.size(1) and random.random() < teacher_forcing_ratio:
                    decoder_input = torch.cat([target_seq[:, t, :2], decoder_input[:, 2:]], dim=1)
                    last_obs = target_seq[:, t, :2]
                else:
                    decoder_input = torch.cat([out_xy, decoder_input[:, 2:]], dim=1)
                    last_obs = out_xy
    
            return torch.cat(outputs, dim=1)  # (batch, output_steps, 2)
    def weighted_coord_loss(pred, target, coord_weight=5.0, reduction='mean'):
        return F.smooth_l1_loss(pred, target, reduction=reduction)
    
    
    def xy_npinn_smoothness_loss(seq_full, coord_min=None, coord_max=None):
        """
        NPINN-inspired smoothness penalty on xy coordinates
        seq_full: [B, T, 2]
        """
        xy = seq_full[..., :2]
    
        if coord_min is not None and coord_max is not None:
            xy_norm = (xy - coord_min) / (coord_max - coord_min + 1e-8)
            xy_norm = 2 * (xy_norm - 0.5)  # [-1,1]
        else:
            xy_norm = xy
    
        v = xy_norm[:, 1:, :] - xy_norm[:, :-1, :]
        a = v[:, 1:, :] - v[:, :-1, :]
        return (v**2).mean() * 0.05 + (a**2).mean() * 0.5
    
    def train_model(model, loader, val_dl, optimizer, device, epochs=50,
                    smooth_w_init=1e-3, coord_min=None, coord_max=None):
    
        best_loss = float('inf')
        best_state = None
    
        for epoch in range(epochs):
            model.train()
            total_loss = total_data_loss = total_smooth_loss = 0.0
    
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(device)  # [B, T_in, F]
                batch_y = batch_y.to(device)  # [B, T_out, 2] absolute x,y
    
                optimizer.zero_grad()
                pred_xy = model(batch_x, target_seq=batch_y, teacher_forcing_ratio=0.5)
    
                # Data loss: directly on absolute x,y
                loss_data = weighted_coord_loss(pred_xy, batch_y)
    
                # Smoothness loss: encourage smooth xy trajectories
                y_start = batch_x[:, :, :2]
                full_seq = torch.cat([y_start, pred_xy], dim=1)  # observed + predicted
                loss_smooth = xy_npinn_smoothness_loss(full_seq, coord_min, coord_max)
    
                smooth_weight = smooth_w_init * max(0.1, 1.0 - epoch / 30.0)
                loss = loss_data + smooth_weight * loss_smooth
    
                loss.backward()
                optimizer.step()
    
                total_loss += loss.item()
                total_data_loss += loss_data.item()
                total_smooth_loss += loss_smooth.item()
    
            avg_loss = total_loss / len(loader)
            print(f"Epoch {epoch+1} | Total: {avg_loss:.6f} | Data: {total_data_loss/len(loader):.6f} | Smooth: {total_smooth_loss/len(loader):.6f}")
    
            # Validation
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for xb, yb in val_dl:
                    xb, yb = xb.to(device), yb.to(device)
                    pred_xy = model(xb, target_seq=yb, teacher_forcing_ratio=0.0)
    
                    data_loss = weighted_coord_loss(pred_xy, yb)
                    full_seq = torch.cat([xb[..., :2], pred_xy], dim=1)
                    loss_smooth = xy_npinn_smoothness_loss(full_seq, coord_min, coord_max)
    
                    val_loss += (data_loss + smooth_weight * loss_smooth).item()
    
            val_loss /= len(val_dl)
            print(f"           Val Loss: {val_loss:.6f}")
    
            if val_loss < best_loss:
                best_loss = val_loss
                best_state = model.state_dict()
    
        if best_state is not None:
            torch.save(best_state, "best_model_NPINN.pth")
            print("Best model saved")
    import torch
    import numpy as np
    import random
    from torch import nn
    
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    train_X = data['train_X']
    input_steps = 80
    output_steps = 2
    
    # Collapse batch and time dims to compute global min/max for each feature
    flat_train_X = train_X.view(-1, train_X.shape[-1])  # shape: [N*T, F]
    
    x_min, x_max = flat_train_X[:, 0].min().item(), flat_train_X[:, 0].max().item()
    y_min, y_max = flat_train_X[:, 1].min().item(), flat_train_X[:, 1].max().item()
    # Feature order is ['x', 'y', 'dx', 'dy', 'cog_sin', 'cog_cos', 'sog_scaled']
    dx_min, dx_max = flat_train_X[:, 2].min().item(), flat_train_X[:, 2].max().item()
    dy_min, dy_max = flat_train_X[:, 3].min().item(), flat_train_X[:, 3].max().item()
    sog_min, sog_max = flat_train_X[:, 6].min().item(), flat_train_X[:, 6].max().item()
    
    coord_min = torch.tensor([x_min, y_min], device=device)
    coord_max = torch.tensor([x_max, y_max], device=device)
    
    # Model setup
    input_size = 7 
    hidden_size = 64
    
    model = Seq2SeqLSTM(input_size, hidden_size, input_steps, output_steps).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    train_model(model, train_dl, val_dl, optimizer, device,
                coord_min=coord_min, coord_max=coord_max)
    import torch
    import joblib
    import json
    import pyproj
    import numpy as np
    from torch.utils.data import DataLoader, TensorDataset
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    
    # --- device ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # recover arrays (assumes you already loaded `data`)
    train_X, train_Y = data['train_X'], data['train_Y']
    val_X, val_Y     = data['val_X'], data['val_Y']
    test_X, test_Y   = data['test_X'], data['test_Y']
    
    # --- load scalers (must match what you saved during preprocessing) ---
    feature_scaler = joblib.load("npinn_feature_scaler.pkl")
    sog_scaler = joblib.load("npinn_sog_scaler.pkl")
    delta_scaler   = joblib.load("npinn_delta_scaler.pkl")    # for dx,dy
    # if you named them differently, change the filenames above
    
    # --- load projection params and build proj ---
    with open("proj_params.json", "r") as f:
        proj_params = json.load(f)
    proj = pyproj.Proj(**proj_params)
    
    # --- rebuild & load model ---
    input_size   = train_X.shape[2]
    input_steps  = train_X.shape[1]
    output_steps = train_Y.shape[1]
    
    hidden_size = 64
    num_layers  = 2
    
    best_model = Seq2SeqLSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        input_steps=input_steps,
        output_steps=output_steps,
    ).to(device)
    
    best_model.load_state_dict(torch.load("best_model_NPINN.pth", map_location=device))
    best_model.eval()
    
    # ---------------- helper inverse/scaling utilities ----------------
    def inverse_dxdy_np(dxdy_scaled, scaler):
        """
        Invert scaled residuals (dx, dy) back to meters.
        dxdy_scaled: (..., 2) scaled
        scaler: RobustScaler/StandardScaler/MinMaxScaler fitted on residuals
        """
        dxdy_scaled = np.asarray(dxdy_scaled, dtype=float)
        if dxdy_scaled.ndim == 1:
            dxdy_scaled = dxdy_scaled[None, :]
    
        n_samples = dxdy_scaled.shape[0]
        n_features = scaler.scale_.shape[0] if hasattr(scaler, "scale_") else scaler.center_.shape[0]
    
        full_scaled = np.zeros((n_samples, n_features))
        full_scaled[:, :2] = dxdy_scaled
    
        if hasattr(scaler, "mean_"):
            center = scaler.mean_
            scale = scaler.scale_
        elif hasattr(scaler, "center_"):
            center = scaler.center_
            scale = scaler.scale_
        else:
            raise ValueError(f"Scaler type {type(scaler)} not supported")
    
        full = full_scaled * scale + center
        return full[:, :2] if dxdy_scaled.shape[0] > 1 else full[0, :2]
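
    # NOTE: inverse_xy_only_np() is called below but is not defined earlier on this page;
    # this sketch assumes it inverts the RobustScaler fitted on
    # ['x', 'y', 'sog', 'cog_sin', 'cog_cos'] for the first two (x, y) columns only.
    def inverse_xy_only_np(xy_scaled, scaler):
        xy_scaled = np.asarray(xy_scaled, dtype=float)
        if xy_scaled.ndim == 1:
            xy_scaled = xy_scaled[None, :]
        center = scaler.center_ if hasattr(scaler, "center_") else scaler.mean_
        scale = scaler.scale_
        xy = xy_scaled * scale[:2] + center[:2]   # x and y are the first two fitted columns
        return xy[0] if xy.shape[0] == 1 else xy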
    
    # ---------------- compute residual std (meters) correctly ----------------
    # train_Y contains scaled residuals (dx,dy) per your preprocessing.
    all_resids_scaled = train_Y.reshape(-1, 2)                      # [sum_T, 2]
    all_resids_m = inverse_dxdy_np(all_resids_scaled, delta_scaler) # meters
    residual_std = np.std(all_resids_m, axis=0)
    print("Computed residual_std (meters):", residual_std)
    
    
    # ---------------- geometry helpers ----------------
    def haversine(lon1, lat1, lon2, lat2):
        """Distance (m) between lon/lat points using haversine formula; handles arrays."""
        R = 6371000.0
        lon1 = np.asarray(lon1, dtype=float)
        lat1 = np.asarray(lat1, dtype=float)
        lon2 = np.asarray(lon2, dtype=float)
        lat2 = np.asarray(lat2, dtype=float)
        # if any entry is non-finite, result will be nan — we'll guard upstream
        lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])
        dlon = lon2 - lon1
        dlat = lat2 - lat1
        a = np.sin(dlat/2.0)**2 + np.cos(lat1)*np.cos(lat2)*np.sin(dlon/2.0)**2
        # numerical stability: clip inside sqrt
        a = np.clip(a, 0.0, 1.0)
        return 2 * R * np.arcsin(np.sqrt(a))
    
    def trajectory_length(lons, lats):
        lons = np.asarray(lons, dtype=float)
        lats = np.asarray(lats, dtype=float)
        if lons.size < 2:
            return 0.0
        # guard non-finite
        if not (np.isfinite(lons).all() and np.isfinite(lats).all()):
            return float("nan")
        return np.sum(haversine(lons[:-1], lats[:-1], lons[1:], lats[1:]))
    def evaluate_with_errors(
        model,
        test_dl,
        proj,
        feature_scaler,
        delta_scaler,
        device,
        num_batches=None,   # None = use full dataset
        dup_tol: float = 1e-4,
        outputs_are_residual_xy: bool = True,
        residual_decode_mode: str = "cumsum",   # "cumsum", "independent", "stdonly"
        residual_std: np.ndarray = None,
        plot_map: bool = True   # <--- PLOT ALL TRAJECTORIES
    ):
        """
        Evaluate model trajectory predictions and report errors in meters.
        Optionally plot all trajectories on a map.
        """
        model.eval()
        errors_all = []
        length_diffs = []
        bad_count = 0
    
        # store all trajectories
        all_real = []
        all_pred = []
    
        with torch.no_grad():
            batches = 0
            for xb, yb in test_dl:
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb, teacher_forcing_ratio=0.0)  # [B, T_out, F]
    
                # first sample of the batch
                input_seq = xb[0].cpu().numpy()
                real_seq  = yb[0].cpu().numpy()
                pred_seq  = pred[0].cpu().numpy()
    
                # Extract dx, dy residuals
                pred_resid_s = pred_seq[:, :2]
                real_resid_s = real_seq[:, :2]
    
                # Invert residuals to meters
                pred_resid_m = inverse_dxdy_np(pred_resid_s, delta_scaler)
                real_resid_m = inverse_dxdy_np(real_resid_s, delta_scaler)
    
                # Use last observed absolute position as starting point (meters)
                last_obs_xy_m = inverse_xy_only_np(input_seq[-1, :2], feature_scaler)
    
                # Reconstruct absolute positions
                if residual_decode_mode == "cumsum":
                    pred_xy_m = np.cumsum(pred_resid_m, axis=0) + last_obs_xy_m
                    real_xy_m = np.cumsum(real_resid_m, axis=0) + last_obs_xy_m
                elif residual_decode_mode == "independent":
                    pred_xy_m = pred_resid_m + last_obs_xy_m
                    real_xy_m = real_resid_m + last_obs_xy_m
                elif residual_decode_mode == "stdonly":
                    if residual_std is None:
                        raise ValueError("residual_std must be provided for 'stdonly' mode")
                    noise = np.random.randn(*pred_resid_m.shape) * residual_std
                    pred_xy_m = np.cumsum(noise, axis=0) + last_obs_xy_m
                    real_xy_m = np.cumsum(real_resid_m, axis=0) + last_obs_xy_m
                else:
                    raise ValueError(f"Unknown residual_decode_mode: {residual_decode_mode}")
    
                # Remove first target if duplicates last input
                if np.allclose(real_resid_m[0], 0, atol=dup_tol):
                    real_xy_m = real_xy_m[1:]
                    pred_xy_m = pred_xy_m[1:]
    
                # align horizon
                min_len = min(len(pred_xy_m), len(real_xy_m))
                if min_len == 0:
                    bad_count += 1
                    continue
    
                pred_xy_m = pred_xy_m[:min_len]
                real_xy_m = real_xy_m[:min_len]
    
                # project to lon/lat
                lon_real, lat_real = proj(real_xy_m[:,0], real_xy_m[:,1], inverse=True)
                lon_pred, lat_pred = proj(pred_xy_m[:,0], pred_xy_m[:,1], inverse=True)
    
                all_real.append((lon_real, lat_real))
                all_pred.append((lon_pred, lat_pred))
    
                # compute per-timestep errors
                errors = haversine(lon_real, lat_real, lon_pred, lat_pred)
                errors_all.append(errors)
    
                # trajectory length diff
                real_len = trajectory_length(lon_real, lat_real)
                pred_len = trajectory_length(lon_pred, lat_pred)
                length_diffs.append(abs(real_len - pred_len))
    
                print(f"Trajectory length (true): {real_len:.2f} m | pred: {pred_len:.2f} m | diff: {abs(real_len - pred_len):.2f} m")
    
                batches += 1
                if num_batches is not None and batches >= num_batches:
                    break
    
        # summary
        if len(errors_all) == 0:
            print("No valid samples evaluated. Bad count:", bad_count)
            return
    
        max_len = max(len(e) for e in errors_all)
        errors_padded = np.full((len(errors_all), max_len), np.nan)
        for i, e in enumerate(errors_all):
            errors_padded[i, :len(e)] = e
    
        mean_per_t = np.nanmean(errors_padded, axis=0)
        print("\n=== Summary (meters) ===")
        for t, v in enumerate(mean_per_t):
            if not np.isnan(v):
                print(f"t={t} mean error: {v:.2f} m")
        print(f"Mean over horizon: {np.nanmean(errors_padded):.2f} m | Median: {np.nanmedian(errors_padded):.2f} m")
        print(f"Mean trajectory length diff: {np.mean(length_diffs):.2f} m | Median: {np.median(length_diffs):.2f} m")
        print("Bad / skipped samples:", bad_count)
    
        # --- plot each trajectory separately ---
        if plot_map and len(all_real) > 0:
            import matplotlib.pyplot as plt
            import cartopy.crs as ccrs
            import cartopy.feature as cfeature
    
            for idx, ((lon_r, lat_r), (lon_p, lat_p)) in enumerate(zip(all_real, all_pred)):
                fig = plt.figure(figsize=(10, 8))
                ax = plt.axes(projection=ccrs.PlateCarree())
                ax.add_feature(cfeature.LAND)
                ax.add_feature(cfeature.COASTLINE)
                ax.add_feature(cfeature.BORDERS, linestyle=':')
                
                ax.plot(lon_r, lat_r, color='green', linewidth=2, label="True")
                ax.plot(lon_p, lat_p, color='red', linestyle='--', linewidth=2, label="Predicted")
                ax.legend()
    
                ax.set_title(f"Trajectory {idx+1}: True vs Predicted")
                plt.show()
    
    Trajectory length (true): 521.39 m | pred: 508.66 m | diff: 12.73 m
    Trajectory length (true): 188.76 m | pred: 206.01 m | diff: 17.25 m
    pip install langchain langchain-community langchain-google-genai sentence-transformers chromadb gradio beautifulsoup4 requests
    export GOOGLE_API_KEY="your-api-key-here"
    import getpass
    import os
    
    from sentence_transformers import SentenceTransformer
    from langchain.chat_models import init_chat_model
    
    if not os.environ.get("GOOGLE_API_KEY"):
      os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")
    
    model = init_chat_model("gemini-2.5-flash", model_provider="google_genai")
    embeddings_model = SentenceTransformer('all-MiniLM-L6-v2')
    import os
    import getpass
    
    from bs4 import BeautifulSoup
    import requests
    
    from langchain.chat_models import init_chat_model
    from langchain_community.vectorstores import Chroma
    from langchain.docstore.document import Document
    from langchain.chains import RetrievalQA
    from langchain.prompts import PromptTemplate
    
    from sentence_transformers import SentenceTransformer
    from langchain.embeddings import HuggingFaceEmbeddings
    
    
    # ----------------------
    # Setup API keys
    # ----------------------
    if not os.environ.get("GOOGLE_API_KEY"):
        os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")
    
    # ----------------------
    # Initialize models
    # ----------------------
    # Use HuggingFace wrapper so LangChain understands SentenceTransformer
    embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    
    # Gemini model (via LangChain)
    llm = init_chat_model("gemini-2.5-flash", model_provider="google_genai")
    
    # ----------------------
    # Initialize Chroma vector store
    # ----------------------
    persist_directory = "./chroma_db"
    vectorstore = Chroma(
        collection_name="aisviz_docs",
        embedding_function=embeddings_model,
        persist_directory=persist_directory,
    )
    
    # ----------------------
    # Scraper (example)
    # ----------------------
    def scrape_aisviz_docs(base_url="https://aisviz.example.com/docs"):
        """
        Scrape AISViz docs and return list of LangChain Documents.
        Adjust selectors based on actual site structure.
        """
        response = requests.get(base_url)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")
    
        docs = []
        for section in soup.find_all("div", class_="doc-section"):
            text = section.get_text(strip=True)
            docs.append(Document(page_content=text, metadata={"source": base_url}))
    
        return docs
    
    # Example: index docs once
    def build_index():
        docs = scrape_aisviz_docs()
        vectorstore.add_documents(docs)
        vectorstore.persist()
    
    # ----------------------
    # Retrieval QA
    # ----------------------
    QA_PROMPT = PromptTemplate(
        input_variables=["context", "question"],
        template="""
    You are an assistant for question-answering tasks about AISViz documentation. 
    Use the following context to answer the user’s question. 
    If the answer is not contained in the context, say you don't know.
    
    Context:
    {context}
    
    Question:
    {question}
    
    Answer concisely:
    """,
    )
    
    retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
    
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": QA_PROMPT},
        return_source_documents=True,
    )
    
    # ----------------------
    # Usage
    # ----------------------
    def answer_question(question: str):
        result = qa_chain({"query": question})
        return result["result"], result["source_documents"]
    
    
    if __name__ == "__main__":
        # Uncomment below line if first time indexing
        # build_index()
    
        ans, sources = answer_question("What is AISViz used for?")
        print("Answer:", ans)
        print("Sources:", [s.metadata["source"] for s in sources])
    
    import requests
    from bs4 import BeautifulSoup
    from urllib.parse import urljoin, urlparse
    import time
    
    def scrape_aisviz_docs():
        """Scrape all AISViz documentation pages"""
        
        base_urls = [
            "https://aisviz.gitbook.io/documentation",
            "https://aisviz.cs.dal.ca",
        ]
        
        scraped_content = []
        visited_urls = set()
        
        for base_url in base_urls:
            # Get the main page
            response = requests.get(base_url)
            soup = BeautifulSoup(response.content, 'html.parser')
            # Extract main content
            main_content = soup.find('div', class_='page-content') or soup.find('main')
            if main_content:
                text = main_content.get_text(strip=True)
                scraped_content.append({
                    'url': base_url,
                    'title': soup.find('title').get_text() if soup.find('title') else 'AISViz Documentation',
                    'content': text
                })
            
            # Find all links to other documentation pages
            links = soup.find_all('a', href=True)
            for link in links:
                href = link['href']
                full_url = urljoin(base_url, href)
                
                # Only follow links within the same domain
                if urlparse(full_url).netloc == urlparse(base_url).netloc and full_url not in visited_urls:
                    visited_urls.add(full_url)
                    try:
                        time.sleep(0.5)  # Be respectful to the server
                        page_response = requests.get(full_url)
                        page_soup = BeautifulSoup(page_response.content, 'html.parser')
                        
                        page_content = page_soup.find('div', class_='page-content') or page_soup.find('main')
                        if page_content:
                            text = page_content.get_text(strip=True)
                            scraped_content.append({
                                'url': full_url,
                                'title': page_soup.find('title').get_text() if page_soup.find('title') else 'AISViz Page',
                                'content': text
                            })
                    except Exception as e:
                        print(f"Error scraping {full_url}: {e}")
                        continue
    
        return scraped_content
    
    # Scrape all documentation
    docs_data = scrape_aisviz_docs()
    print(f"Scraped {len(docs_data)} pages from AISViz documentation")
    def split_text(text, chunk_size=1000, chunk_overlap=200):
        """Split text into overlapping chunks"""
        chunks = []
        start = 0
        
        while start < len(text):
            # Find the end of this chunk
            end = start + chunk_size
            
            # If this isn't the last chunk, try to break at a sentence or word boundary
            if end < len(text):
                # Look for sentence boundary
                last_period = text.rfind('.', start, end)
                last_newline = text.rfind('\n', start, end)
                last_space = text.rfind(' ', start, end)
                
                # Use the best boundary we can find
                if last_period > start + chunk_size // 2:
                    end = last_period + 1
                elif last_newline > start + chunk_size // 2:
                    end = last_newline
                elif last_space > start + chunk_size // 2:
                    end = last_space
            
            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)
            # Move start position for next chunk (with overlap)
            start = end - chunk_overlap if end < len(text) else end
        
        return chunks
    
    # Split all documents into chunks
    all_chunks = []
    chunk_metadata = []
    
    for doc_data in docs_data:
        chunks = split_text(doc_data['content'])
        for i, chunk in enumerate(chunks):
            all_chunks.append(chunk)
            chunk_metadata.append({
                'source': doc_data['url'],
                'title': doc_data['title'],
                'chunk_id': i
            })
    
    print(f"Split documentation into {len(all_chunks)} chunks")
    from sentence_transformers import SentenceTransformer
    import chromadb
    
    # Initialize SentenceTransformer model
    embeddings_model = SentenceTransformer('all-MiniLM-L6-v2')
    
    # Create Chroma client and collection
    chroma_client = chromadb.PersistentClient(path="./chroma_db")
    collection = chroma_client.get_or_create_collection(name="aisviz_docs")
    
    # Create embeddings for all chunks (this may take a few minutes)
    print("Creating embeddings for all chunks...")
    chunk_embeddings = embeddings_model.encode(all_chunks, show_progress_bar=True)
    # Prepare data for Chroma
    chunk_ids = [f"chunk_{i}" for i in range(len(all_chunks))]
    metadatas = chunk_metadata
    
    # Add everything to Chroma collection
    collection.add(
        embeddings=chunk_embeddings.tolist(),
        documents=all_chunks,
        metadatas=metadatas,
        ids=chunk_ids
    )
    
    print("Documents indexed and stored in Chroma database")
    import getpass
    import os
    
    # ----------------------
    # 1. API Key Setup
    # ----------------------
    # Check if GOOGLE_API_KEY is already set in the environment.
    # If not, securely prompt the user to enter it (won’t show in terminal).
    if not os.environ.get("GOOGLE_API_KEY"):
        os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")
    
    from langchain.chat_models import init_chat_model
    
    # ----------------------
    # 2. Initialize Gemini Model (via LangChain)
    # ----------------------
    # We create a chat model wrapper around Gemini.
    # "gemini-2.5-flash" is a lightweight, fast generation model.
    # LangChain handles the API calls under the hood.
    model = init_chat_model("gemini-2.5-flash", model_provider="google_genai")
    
    
    # ----------------------
    # 3. Define RAG Function
    # ----------------------
    def answer_question(question, k=4):
        """
        Answer a question using Retrieval-Augmented Generation (RAG)
        over AISViz documentation stored in ChromaDB.
        
        Args:
            question (str): User question.
            k (int): Number of top documents to retrieve.
        Returns:
            dict with keys: 'answer', 'sources', 'context'.
        """
    
        # Step 1: Create an embedding for the question
        # Encode the question into a dense vector using SentenceTransformers.
        # This embedding will be used to search for semantically similar docs.
        question_embedding = embeddings_model.encode([question])
    
        # Step 2: Retrieve relevant documents from Chroma
        # Query the Chroma vector store for the top-k most relevant docs.
        results = collection.query(
            query_embeddings=question_embedding.tolist(),
            n_results=k
        )
    
        # Step 3: Build context from retrieved documents
        # Extract both content and metadata for each retrieved doc.
        retrieved_docs = results['documents'][0]
        retrieved_metadata = results['metadatas'][0]
    
        # Join all documents into one context string.
        # Each doc is prefixed with its source/title for attribution.
        context = "\n\n".join([
            f"Source: {meta.get('title', 'AISViz Documentation')}\n{doc}"
            for doc, meta in zip(retrieved_docs, retrieved_metadata)
        ])
    
        # Step 4: Construct prompts
        # System prompt: instructs Gemini to behave like a doc-based assistant.
        system_prompt = """You are an assistant for question-answering tasks about AISViz documentation and maritime vessel tracking. 
        Use the following pieces of retrieved context to answer the question. 
        If you don't know the answer based on the context, just say that you don't know. 
        Keep the answer concise and helpful.
        Always mention which sources you're referencing when possible."""
    
        # User prompt: combines the retrieved context with the actual user question.
        user_prompt = f"""Context:
    {context}
    
    Question: {question}
    
    Answer:"""
    
        # Step 5: Generate answer using Gemini
        # Combine system + user prompts and send to Gemini.
        full_prompt = f"{system_prompt}\n\n{user_prompt}"
    
        try:
            response = model.invoke(full_prompt)
            answer = response.content  # Extract plain text from the LangChain message
        except Exception as e:
            # If Gemini API call fails, catch error and return message
            answer = f"Sorry, I encountered an error generating the response: {str(e)}"
    
        # Return a structured result: answer text, sources, and raw context.
        return {
            'answer': answer,
            'sources': [meta.get('source', '') for meta in retrieved_metadata],
            'context': context
        }
    
    
    # ----------------------
    # 4. Test the Function
    # ----------------------
    result = answer_question("What is AISViz?")
    print("Answer:", result['answer'])
    print("Sources:", result['sources'])
    
    import gradio as gr
    
    def chatbot_interface(message, history):
        """Interface function for Gradio chatbot"""
        try:
            result = answer_question(message)
            response = result['answer']
            
            # Add source information to the response
            if result['sources']:
                unique_sources = list(set(result['sources'][:3]))
                sources_text = "\n\n**Sources:**\n" + "\n".join([f"- {source}" for source in unique_sources])
                response += sources_text
                
            return response
        except Exception as e:
            return f"Sorry, I encountered an error: {str(e)}"
    
    # Create Gradio interface
    demo = gr.ChatInterface(
        fn=chatbot_interface,
        title="🚢 AISViz Documentation Chatbot",
        description="Ask questions about AISViz, AISdb, and maritime vessel tracking!",
        examples=[
            "What is AISViz?",
            "How do I get started with AISdb?",
            "What kind of data does AISViz work with?",
            "How can I analyze vessel trajectories?"
        ],
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear History",
    )
    
    if __name__ == "__main__":
        demo.launch()
    import os
    import h3
    import json
    import aisdb
    import cartopy.feature as cfeature
    import cartopy.crs as ccrs
    from aisdb.database.dbconn import PostgresDBConn
    from aisdb.denoising_encoder import encode_greatcircledistance, InlandDenoising
    from aisdb.track_gen import min_speed_filter, min_track_length_filter
    from aisdb.database import sqlfcn
    from datetime import datetime, timedelta
    from collections import defaultdict
    from tqdm import tqdm
    import pprint
    import numpy as np
    import geopandas as gpd
    import matplotlib.pyplot as plt
    
    import nest_asyncio
    nest_asyncio.apply()
    def process_interval(dbconn, start, end):
        # Open a new connection with the database
        qry = aisdb.DBQuery(dbconn=dbconn, start=start, end=end,
                            xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax,
                            callback=aisdb.database.sqlfcn_callbacks.in_bbox_time_validmmsi)
        
        # Decimate is for removing unnecessary points in the trajectory
        rowgen = qry.gen_qry(fcn=sqlfcn.crawl_dynamic_static)
        tracks = aisdb.track_gen.TrackGen(rowgen, decimate=False)
    
        with InlandDenoising(data_dir='./data/tmp/') as remover:
            cleaned_tracks = remover.filter_noisy_points(tracks)
        
        # Split the tracks based on the time between transmissions
        track_segments = aisdb.track_gen.split_timedelta(cleaned_tracks, time_split)
        
        # Filter out segments that are below the min-score threshold
        tracks_encoded = encode_greatcircledistance(track_segments, distance_threshold=distance_split, speed_threshold=speed_split)
        tracks_encoded = min_speed_filter(tracks_encoded, minspeed=1)
    
        # Interpolate the segments every five minutes to enforce continuity
        tracks_interpolated = aisdb.interp.interp_time(tracks_encoded, step=timedelta(minutes=1))
    
        # Returns a generator from a Python variable
        return list(tracks_interpolated)
    # Load the shapefile
    gulf_shapefile = './data/region/gulf.shp'
    
    print(f"Loading shapefile from {gulf_shapefile}...")
    gdf_gulf = gpd.read_file(gulf_shapefile)
    
    gdf_hexagons = gpd.read_file('./data/cell/Hexagons_6.shp')
    gdf_hexagons = gdf_hexagons.to_crs(epsg=4326)  # Consistent CRS projection
    # valid_h3_ids = set(gdf_hexagons['hex_id'])
    bounding_box = gdf_hexagons.total_bounds  # Extract the bounding box
    # bounding_box = gdf_gulf.total_bounds  # Extract the bounding box
    xmin, ymin, xmax, ymax = bounding_box  # Split the bounding box
    
    start_date = datetime(2023, 1, 1)
    end_date = datetime(2023, 1, 30)
    print(f"Processing trajectories from {start_date} to {end_date}")
    # Define pre-processing parameters
    time_split = timedelta(hours=3)
    distance_split = 10000  # meters
    speed_split = 40  # knots
    
    cell_visits = defaultdict(lambda: defaultdict(list))
    filtered_visits = defaultdict(lambda: defaultdict(list))
    g2h3_vec = np.vectorize(h3.latlng_to_cell)
    pp = pprint.PrettyPrinter(indent=4)
    
    track_info_list = []
    track_list = process_interval(dbconn, start_date, end_date)
    for track in tqdm(track_list, total=len(track_list), desc="Vessels", leave=False):
        h3_ids = g2h3_vec(track['lat'], track['lon'], 6)
        timestamps = track['time']
        # Identify the entry points of cells on a track
        # Deduplicate consecutive identical h3_ids while preserving the entry timestamp
        dedup_h3_ids = [h3_ids[0]]
        dedup_timestamps = [timestamps[0]]
        for i in range(1, len(h3_ids)):
            if h3_ids[i] != dedup_h3_ids[-1]:
                dedup_h3_ids.append(h3_ids[i])
                dedup_timestamps.append(timestamps[i])
        track_info = {
            "mmsi": track['mmsi'],
            "h3_seq": dedup_h3_ids,
            "timestamp_seq": dedup_timestamps
        }
        track_info_list.append(track_info)
    import seaborn as sns
    
    def plot_length_distribution(track_lengths):
        # Compute summary stats
        length_stats = {
            "min": int(np.min(track_lengths)),
            "max": int(np.max(track_lengths)),
            "mean": float(np.mean(track_lengths)),
            "median": float(np.median(track_lengths)),
            "percentiles": {
                "10%": int(np.percentile(track_lengths, 10)),
                "25%": int(np.percentile(track_lengths, 25)),
                "50%": int(np.percentile(track_lengths, 50)),
                "75%": int(np.percentile(track_lengths, 75)),
                "90%": int(np.percentile(track_lengths, 90)),
                "95%": int(np.percentile(track_lengths, 95)),
            }
        }
        print(length_stats)
    
        # Plot distribution
        plt.figure(figsize=(10, 6))
        sns.histplot(track_lengths, bins=100, kde=True)
        plt.title("Distribution of Track Lengths")
        plt.xlabel("Track Length (number of H3 cells)")
        plt.ylabel("Frequency")
        plt.grid(True)
        plt.tight_layout()
        plt.show()
    
    
    def map_view(tracks, dot_size=3, color=None, save=False, path=None, bbox=None, line=False, line_width=0.5, line_opacity=0.3):
        fig = plt.figure(figsize=(16, 9))
        ax = plt.axes(projection=ccrs.PlateCarree())
    
        # Add cartographic features
        ax.add_feature(cfeature.OCEAN.with_scale('10m'), facecolor='#E0E0E0')
        ax.add_feature(cfeature.LAND.with_scale('10m'), facecolor='#FFE5CC')
        ax.add_feature(cfeature.BORDERS, linestyle=':')
        ax.add_feature(cfeature.LAKES, alpha=0.5)
        ax.add_feature(cfeature.RIVERS)
        ax.coastlines(resolution='10m')
    
        if line:
            for track in tqdm(tracks):
                ax.plot(track['lon'], track['lat'], color=color, linewidth=line_width, alpha=line_opacity, transform=ccrs.PlateCarree())
        else:
            for track in tqdm(tracks):
                ax.scatter(track['lon'], track['lat'], c=color, s=dot_size, transform=ccrs.PlateCarree())
    
        if bbox:
            # Set the map extent based on a bounding box
            ax.set_extent(bbox, crs=ccrs.PlateCarree())
    
        ax.gridlines(draw_labels=True)
    
        if save:
            plt.savefig(path, dpi=300, transparent=True)
    
        plt.show()
    
    
    def hex_view(lats, lons, save=True):
        plt.figure(figsize=(8,8))
        for traj_lat, traj_lon in zip(lats, lons):
            plt.plot(traj_lon, traj_lat, alpha=0.3, linewidth=1)
    
        plt.xlabel("Longitude")
        plt.ylabel("Latitude")
        plt.title("Test Trajectories")
        plt.axis("equal")
        if save:
            plt.savefig("img/test_track.png", dpi=300)
    track_info_list = [t for t in track_info_list if (len(t['h3_seq']) >= 10)&(len(t['h3_seq']) <= 300)]
    vec_cell_to_latlng = np.vectorize(h3.cell_to_latlng)
    
    # Extract hex ids from all tracks
    all_h3_ids = set()
    for track in track_info_list:
        all_h3_ids.update(track['h3_seq'])  # or t['int_seq'] if already mapped
    
    # Build vocab: reserve 0, 1, 2 for <PAD>, <BOS>, <EOS>; real cells start at index 3
    h3_vocab = {h: i+3 for i, h in enumerate(sorted(all_h3_ids))}
    # h3_vocab = {h: i+3 for i, h in enumerate(sorted(valid_h3_ids))}
    special_tokens = {"<PAD>": 0, "<BOS>": 1, "<EOS>": 2}
    h3_vocab.update(special_tokens)
    
    for t in track_info_list:
        t["int_seq"] = [h3_vocab[h] for h in t["h3_seq"] if h in h3_vocab]
        t["lat"], t["lon"] = vec_cell_to_latlng(t.get('h3_seq'))
    
    from sklearn.model_selection import train_test_split
    
    # Initial split: train vs temp
    train_tracks, temp_tracks = train_test_split(track_info_list, test_size=0.4, random_state=42)
    # Second split: validation vs test
    val_tracks, test_tracks = train_test_split(temp_tracks, test_size=0.5, random_state=42)
    
    def save_data(tracks, prefix, output_dir="data"):
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, f"{prefix}.src"), "w") as f_src, \
             open(os.path.join(output_dir, f"{prefix}.trg"), "w") as f_trg, \
             open(os.path.join(output_dir, f"{prefix}.lat"), "w") as f_lat, \
             open(os.path.join(output_dir, f"{prefix}.lon"), "w") as f_lon, \
             open(os.path.join(output_dir, f"{prefix}_trj.t"), "w") as f_t:
    
            for idx, t in enumerate(tracks):
                ids = t["int_seq"]
                src = ids[:-1]
                trg = ids[1:]
                
                f_t.write(" ".join(map(str, ids)) + "\n")   # the whole track, t = src U trg
                f_src.write(" ".join(map(str, src)) + "\n")
                f_trg.write(" ".join(map(str, trg)) + "\n")
                f_lat.write(" ".join(map(str, t.get('lat'))) + "\n")
                f_lon.write(" ".join(map(str, t.get('lon'))) + "\n")
    
    def save_h3_vocab(h3_vocab, output_dir="data", filename="vocab.json"):
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, filename), "w") as f:
            json.dump(h3_vocab, f, indent=2)
            
            
    save_data(train_tracks, "train")
    save_data(val_tracks, "val")
    save_data(test_tracks, "test")
    save_h3_vocab(h3_vocab) # save the INT index mapping to H3 index
    
    lats = [np.fromstring(line, sep=' ') for line in open("data/train.lat")]
    lons = [np.fromstring(line, sep=' ') for line in open("data/train.lon")]
    
    hex_view(lats, lons)
    import os
    import numpy as np
    
    import torch
    import torch.nn as nn
    from torch.nn.utils import clip_grad_norm_
    # from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
    from torch.utils.tensorboard import SummaryWriter
    
    # from funcy import merge
    import time, os, shutil, logging, h5py
    # from collections import namedtuple
    
    from model.t2vec import EncoderDecoder
    from data_loader import DataLoader
    from utils import *
    from model.loss import *
    
    writer = SummaryWriter()
    
    PAD = 0
    BOS = 1
    EOS = 2
    def init_parameters(model):
        for p in model.parameters():
            p.data.uniform_(-0.1, 0.1)
    
    def savecheckpoint(state, is_best, args):
        torch.save(state, args.checkpoint)
        if is_best:
            shutil.copyfile(args.checkpoint, os.path.join(args.data, 'best_model.pt'))
    
    def validate(valData, model, lossF, args):
        """
        valData (DataLoader)
        """
        m0, m1 = model
        ## switch to evaluation mode
        m0.eval()
        m1.eval()
    
        num_iteration = valData.size // args.batch
        if valData.size % args.batch > 0: num_iteration += 1
    
        total_genloss = 0
        for iteration in range(num_iteration):
            gendata = valData.getbatch_generative()
            with torch.no_grad():
                genloss = genLoss(gendata, m0, m1, lossF, args)
                total_genloss += genloss.item() * gendata.trg.size(1)
        ## switch back to training mode
        m0.train()
        m1.train()
        return total_genloss / valData.size
    
    
    def train(args):
        logging.basicConfig(filename=os.path.join(args.data, "training.log"), level=logging.INFO)
    
        trainsrc = os.path.join(args.data, "train.src") # data path
        traintrg = os.path.join(args.data, "train.trg")
        trainlat = os.path.join(args.data, "train.lat")
        trainlon = os.path.join(args.data, "train.lon")
        # trainmta = os.path.join(args.data, "train.mta")
        trainData = DataLoader(trainsrc, traintrg, trainlat, trainlon, args.batch, args.bucketsize)
        print("Reading training data...")
        trainData.load(args.max_num_line)
        print("Allocation: {}".format(trainData.allocation))
        print("Percent: {}".format(trainData.p))
    
        valsrc = os.path.join(args.data, "val.src")
        valtrg = os.path.join(args.data, "val.trg")
        vallat = os.path.join(args.data, "val.lat")
        vallon = os.path.join(args.data, "val.lon")
        if os.path.isfile(valsrc) and os.path.isfile(valtrg):
            valData = DataLoader(valsrc, valtrg, vallat, vallon, args.batch, args.bucketsize, validate=True)
            print("Reading validation data...")
            valData.load()
            assert valData.size > 0, "Validation data size must be greater than 0"
            print("Loaded validation data size {}".format(valData.size))
        else:
            print("No validation data found, training without validating...")
    
        ## create criterion, model, optimizer
        if args.criterion_name == "NLL":
            criterion = NLLcriterion(args.vocab_size)
            lossF = lambda o, t: criterion(o, t)
        else:
            assert os.path.isfile(args.knearestvocabs),\
                "{} does not exist".format(args.knearestvocabs)
            print("Loading vocab distance file {}...".format(args.knearestvocabs))
            with h5py.File(args.knearestvocabs, "r") as f:
                V, D = f["V"][...], f["D"][...]
                V, D = torch.LongTensor(V), torch.FloatTensor(D)
            D = dist2weight(D, args.dist_decay_speed)
            if args.cuda and torch.cuda.is_available():
                V, D = V.cuda(), D.cuda()
            criterion = KLDIVcriterion(args.vocab_size)
            lossF = lambda o, t: KLDIVloss(o, t, criterion, V, D)
        triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
    
        m0 = EncoderDecoder(args.vocab_size,
                            args.embedding_size,
                            args.hidden_size,
                            args.num_layers,
                            args.dropout,
                            args.bidirectional)
        m1 = nn.Sequential(nn.Linear(args.hidden_size, args.vocab_size),
                           nn.LogSoftmax(dim=1))
        if args.cuda and torch.cuda.is_available():
            print("=> training with GPU")
            m0.cuda()
            m1.cuda()
            criterion.cuda()
            #m0 = nn.DataParallel(m0, dim=1)
        else:
            print("=> training with CPU")
    
        m0_optimizer = torch.optim.Adam(m0.parameters(), lr=args.learning_rate)
        m1_optimizer = torch.optim.Adam(m1.parameters(), lr=args.learning_rate)
    
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            m0_optimizer,
            mode='min',
            patience=args.lr_decay_patience,
            min_lr=0,
            verbose=True
        )
    
        ## load model state and optimizer state
        if os.path.isfile(args.checkpoint):
            print("=> loading checkpoint '{}'".format(args.checkpoint))
            logging.info("Restore training @ {}".format(time.ctime()))
            checkpoint = torch.load(args.checkpoint)
            args.start_iteration = checkpoint["iteration"]
            best_prec_loss = checkpoint["best_prec_loss"]
            m0.load_state_dict(checkpoint["m0"])
            m1.load_state_dict(checkpoint["m1"])
            m0_optimizer.load_state_dict(checkpoint["m0_optimizer"])
            m1_optimizer.load_state_dict(checkpoint["m1_optimizer"])
        else:
            print("=> no checkpoint found at '{}'".format(args.checkpoint))
            logging.info("Start training @ {}".format(time.ctime()))
            best_prec_loss = float('inf')
            args.start_iteration = 0
            print("=> initializing the parameters...")
            init_parameters(m0)
            init_parameters(m1)
            ## here: optionally load a pretrained word (cell) embedding
    
        num_iteration = 6700*128 // args.batch
        print("Iteration starts at {} "
              "and will end at {}".format(args.start_iteration, num_iteration-1))
        
        no_improvement_count = 0
    
        ## training
        for iteration in range(args.start_iteration, num_iteration):
            try:
                m0_optimizer.zero_grad()
                m1_optimizer.zero_grad()
                ## generative loss
                gendata = trainData.getbatch_generative()
                genloss = genLoss(gendata, m0, m1, lossF, args)
                ## discriminative loss
                disloss_cross, disloss_inner = 0, 0
                if args.use_discriminative and iteration % 5 == 0:
                    a, p, n = trainData.getbatch_discriminative_cross()
                    disloss_cross = disLoss(a, p, n, m0, triplet_loss, args)
                    a, p, n = trainData.getbatch_discriminative_inner()
                    disloss_inner = disLoss(a, p, n, m0, triplet_loss, args)
                loss = genloss + args.discriminative_w * (disloss_cross + disloss_inner)
    
                # Add to tensorboard
                writer.add_scalar('Loss/train', loss, iteration)
    
                ## compute the gradients
                loss.backward()
                ## clip the gradients
                clip_grad_norm_(m0.parameters(), args.max_grad_norm)
                clip_grad_norm_(m1.parameters(), args.max_grad_norm)
                ## one step optimization
                m0_optimizer.step()
                m1_optimizer.step()
                ## average generative loss per target token
                avg_genloss = genloss.item() / gendata.trg.size(0)
    
                if iteration % args.print_freq == 0:
                    print("Iteration: {0:}\tGenerative Loss: {1:.3f}\t"\
                          "Discriminative Cross Loss: {2:.3f}\tDiscriminative Inner Loss: {3:.3f}"\
                          .format(iteration, avg_genloss, disloss_cross, disloss_inner))
                
                if iteration % args.save_freq == 0 and iteration > 0:
                    prec_loss = validate(valData, (m0, m1), lossF, args)
    
                    # Add to tensorboard
                    writer.add_scalar('Loss/validation', prec_loss, iteration)
    
                    scheduler.step(prec_loss)
                    if prec_loss < best_prec_loss:
                        best_prec_loss = prec_loss
                        logging.info("Best model with loss {} at iteration {} @ {}"\
                                     .format(best_prec_loss, iteration, time.ctime()))
                        is_best = True
                        no_improvement_count = 0
                    else:
                        is_best = False
                        no_improvement_count += 1
    
                    print("Saving the model at iteration {} validation loss {}"\
                          .format(iteration, prec_loss))
                    savecheckpoint({
                        "iteration": iteration,
                        "best_prec_loss": best_prec_loss,
                        "m0": m0.state_dict(),
                        "m1": m1.state_dict(),
                        "m0_optimizer": m0_optimizer.state_dict(),
                        "m1_optimizer": m1_optimizer.state_dict()
                    }, is_best, args)
    
                    # Early stopping when the validation loss has not improved for a set number of validation checks
                    if no_improvement_count >= args.early_stopping_patience:
                        print('No improvement after {} iterations, early stopping triggered.'.format(args.early_stopping_patience))
                        break
                    
            except KeyboardInterrupt:
                break
    
    
    def test(args):
        # load testing data
        testsrc = os.path.join(args.data, "test.src")
        testtrg = os.path.join(args.data, "test.trg")
        testlat = os.path.join(args.data, "test.lat")
        testlon = os.path.join(args.data, "test.lon")
        if os.path.isfile(testsrc) and os.path.isfile(testtrg):
            testData = DataLoader(testsrc, testtrg, testlat, testlon, args.batch, args.bucketsize, validate=True)
            print("Reading testing data...")
            testData.load()
            assert testData.size > 0, "Testing data size must be greater than 0"
            print("Loaded testing data size {}".format(testData.size))
        else:
            print("No testing data found, aborting test.")
            return
    
        # set up model
        m0 = EncoderDecoder(args.vocab_size,
                            args.embedding_size,
                            args.hidden_size,
                            args.num_layers,
                            args.dropout,
                            args.bidirectional)
        m1 = nn.Sequential(nn.Linear(args.hidden_size, args.vocab_size),
                           nn.LogSoftmax(dim=1))
    
        # load best model state
        best_model_path = 'data/best_model.pt'
        if os.path.isfile(best_model_path):
            print("=> loading checkpoint '{}'".format(args.checkpoint))
            best_model = torch.load(best_model_path)
            m0.load_state_dict(best_model["m0"])
            m1.load_state_dict(best_model["m1"])
        else:
            print("Best model not found. Aborting test.")
            return
    
        m0.eval()
        m1.eval()
    
        # loss function
        criterion = NLLcriterion(args.vocab_size)
        lossF = lambda o, t: criterion(o, t)
    
        # check device
        if args.cuda and torch.cuda.is_available():
            print("=> test with GPU")
            m0.cuda()
            m1.cuda()
            criterion.cuda()
            #m0 = nn.DataParallel(m0, dim=1)
        else:
            print("=> test with CPU")
    
        num_iteration = (testData.size + args.batch - 1) // args.batch
        total_genloss = 0
        total_tokens = 0
    
        with torch.no_grad():
            for iteration in range(num_iteration):
                gendata = testData.getbatch_generative()
                genloss = genLoss(gendata, m0, m1, lossF, args)
                total_genloss += genloss.item()  # accumulate per-batch generative loss
                total_tokens += (gendata.trg != PAD).sum().item()  # count non-padding target tokens
                print("Testing genloss at {} iteration is {}".format(iteration, total_genloss))
            
        avg_loss = total_genloss / total_tokens
        perplexity = torch.exp(torch.tensor(avg_loss))
    
        print(f"[Test] Avg Loss: {avg_loss:.4f} | Perplexity: {perplexity:.2f}")
    
    class Args:
        data = 'data/'
        checkpoint = 'data/checkpoint.pt'
        vocab_size = len(h3_vocab)
        embedding_size = 128
        hidden_size = 128
        num_layers = 1
        dropout = 0.1
        max_grad_norm = 1.0
        learning_rate = 1e-2
        lr_decay_patience = 20
        early_stopping_patience = 50
        cuda = torch.cuda.is_available()
        bidirectional = True
        batch = 16
        num_epochs = 100
        bucketsize = [(20,30),(30,30),(30,50),(50,50),(50,70),(70,70),(70,100),(100,100)]
        criterion_name = "NLL"
        use_discriminative = True
        discriminative_w = 0.1
        max_num_line = 200000
        start_iteration = 0
        generator_batch = 16
        print_freq = 10
        save_freq = 10
    
    
    args = Args()
    train(args)
    test(args)
    Testing genloss at 0 iteration is 46.40993881225586
    Testing genloss at 1 iteration is 83.17555618286133
    Testing genloss at 2 iteration is 122.76013565063477
    Testing genloss at 3 iteration is 167.81907272338867
    Testing genloss at 4 iteration is 223.75146102905273
    Testing genloss at 5 iteration is 287.765926361084
    Testing genloss at 6 iteration is 328.6252250671387
    Testing genloss at 7 iteration is 394.95031356811523
    Testing genloss at 8 iteration is 459.411678314209
    Testing genloss at 9 iteration is 557.8198432922363
    Testing genloss at 10 iteration is 724.5464973449707
    Testing genloss at 11 iteration is 876.1395149230957
    Testing genloss at 12 iteration is 1020.6461372375488
    Testing genloss at 13 iteration is 1277.3499336242676
    Testing genloss at 14 iteration is 1416.0101203918457
    Testing genloss at 15 iteration is 1742.3399543762207
    Testing genloss at 16 iteration is 2101.4984016418457
    Testing genloss at 17 iteration is 2319.603458404541
    [Test] Avg Loss: 0.2309 | Perplexity: 1.26
    Difference: 17.25 m

    Here, the predicted trajectory overestimates the total distance by roughly 9%, again producing a smooth path that is slightly offset from the true trajectory. The model captures the overall direction well but introduces a small scaling error in the individual step lengths.
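
    The distance figures above can be reproduced with a short check that sums the haversine distance between consecutive points of each track and compares the totals. The sketch below is a minimal illustration, assuming true_lat/true_lon and pred_lat/pred_lon hold the coordinates of one ground-truth track and its predicted counterpart (these names are placeholders, not variables defined earlier in this notebook).

    import numpy as np

    def haversine_km(lat1, lon1, lat2, lon2):
        # great-circle distance between point pairs, in kilometres
        lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
        a = np.sin((lat2 - lat1) / 2.0) ** 2 \
            + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2.0) ** 2
        return 2.0 * 6371.0 * np.arcsin(np.sqrt(a))

    def track_length_km(lat, lon):
        # total along-track distance of one trajectory
        lat, lon = np.asarray(lat, dtype=float), np.asarray(lon, dtype=float)
        return haversine_km(lat[:-1], lon[:-1], lat[1:], lon[1:]).sum()

    # true_lat/true_lon and pred_lat/pred_lon are placeholder arrays holding one
    # ground-truth track and the corresponding predicted track
    true_len = track_length_km(true_lat, true_lon)
    pred_len = track_length_km(pred_lat, pred_lon)
    print(f"True: {true_len:.3f} km | Predicted: {pred_len:.3f} km | "
          f"Difference: {abs(pred_len - true_len) * 1000:.2f} m "
          f"({100 * (pred_len - true_len) / true_len:+.1f}%)")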

    https://huggingface.co/spaces/mapslab/AISVIZ-BOT/tree/main