function [X,Y,T,M,Minit] = ShallowWaterModel(init_fcn, t_final, n_res, x_domain, y_domain)
% SHALLOWWATERMODEL  Integrate the rotating shallow-water equations on a
% periodic rectangle with a second-order predictor-corrector finite-volume
% scheme (local Lax-Friedrichs fluxes plus a Coriolis source term).
%
% Inputs:
%   init_fcn - function handle, [h, u, v] = init_fcn(x, y), giving the initial
%              depth and velocity components at a cell centroid
%   t_final  - simulation end time (seconds)
%   n_res    - number of cells in each direction
%   x_domain - [xmin xmax]
%   y_domain - [ymin ymax]
%
% Outputs:
%   X, Y  - 1 x n_res vectors of cell-centroid coordinates
%   T     - nt x 1 vector of the time reached after each step
%   M     - n_res x n_res x 3 x nt array of [h, u, v] after each step;
%           M(j,i,:,t) is the cell centered at (X(i), Y(j))
%   Minit - n_res x n_res x 3 initial [h, u, v], same (y, x) orientation as M

global gravity omega phi0

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ONLY MODIFY THESE PARAMETERS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Constant of gravity
gravity = 9.80616;

% Rotation rate
omega = 7.292e-5;

% Latitude
phi0 = pi/4;

% CFL number
cfl = 0.5;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Coriolis parameter (constant over the domain, evaluated at latitude phi0)
f = 2 * omega * sin(phi0);

% Halo size (the reconstruction reads two cells on each side of an edge)
halo = 2;

% Grid spacing
dx = (x_domain(2) - x_domain(1)) / n_res;
dy = (y_domain(2) - y_domain(1)) / n_res;

% Cell edges (n_res+1 per direction); centroids are midpoints of adjacent edges
XEdge = (0:n_res) * dx + x_domain(1);
YEdge = (0:n_res) * dy + y_domain(1);

XCentroid = 0.5 * (XEdge(1:n_res) + XEdge(2:n_res+1));
YCentroid = 0.5 * (YEdge(1:n_res) + YEdge(2:n_res+1));

X = XCentroid;
Y = YCentroid;

% Indices of the physical (non-halo) cells inside the padded arrays
interior = halo+1:n_res+halo;

% Initialize the conserved variables; first index is x, second index is y
h = zeros(n_res + 2*halo);
hu = zeros(n_res + 2*halo);
hv = zeros(n_res + 2*halo);

for i = 1:n_res
for j = 1:n_res
    io = i + halo;
    jo = j + halo;

    [h(io,jo), hu(io,jo), hv(io,jo)] = init_fcn(XCentroid(i), YCentroid(j));

    % init_fcn supplies velocities; convert them to momenta
    hu(io,jo) = hu(io,jo) * h(io,jo);
    hv(io,jo) = hv(io,jo) * h(io,jo);
end
end

% Store initial state as [h, u, v], transposed to (y, x) to match M below
h_in = h(interior, interior);
Minit = zeros(n_res, n_res, 3);
Minit(:,:,1) = h_in';
Minit(:,:,2) = (hu(interior, interior) ./ h_in)';
Minit(:,:,3) = (hv(interior, interior) ./ h_in)';

% Fixed timestep from the CFL condition on the initial physical cells, using
% the signal speed |velocity| + sqrt(g h)
wave_speed = sqrt(hu(interior, interior).^2 + hv(interior, interior).^2) ./ h_in ...
    + sqrt(gravity * h_in);
max_wave_speed = max(wave_speed(:));

if isempty(max_wave_speed) || ~(max_wave_speed > 0)
    error('ShallowWaterModel:InvalidState', ...
        'Initial state has no positive wave speed; cannot choose a timestep.');
end

dt = cfl * min(dx, dy) / max_wave_speed;

% Number of timesteps
nt = ceil(t_final / dt);

% Output timestep information
disp(sprintf('Time step size: %1.2f seconds', dt));
disp(sprintf('Number of time steps: %i', nt));

% Time at each timestep
T = zeros(nt,1);

% Allocate space for intermediate steps
M = zeros(n_res, n_res, 3, nt);

% Loop through all timesteps
t_current = 0;

for t = 1:nt

    % Shorten the final step so the run ends exactly at t_final
    if (t == nt)
        dt = t_final - t_current;
    end

    % Apply boundary conditions
    h  = ApplyBoundaryConditions(h, n_res, halo);
    hu = ApplyBoundaryConditions(hu, n_res, halo);
    hv = ApplyBoundaryConditions(hv, n_res, halo);

    % Store old variables
    h_old = h;
    hu_old = hu;
    hv_old = hv;

    % Predictor: advance half a step with fluxes and Coriolis source taken
    % from the old state
    XFlux = CalculateXFluxes(h, hu, hv, n_res, halo);
    YFlux = CalculateYFluxes(h, hu, hv, n_res, halo);

    [h,hu,hv] = UpdateXFluxes(h, hu, hv, XFlux, n_res, halo, dx, 0.5 * dt);
    [h,hu,hv] = UpdateYFluxes(h, hu, hv, YFlux, n_res, halo, dy, 0.5 * dt);

    % Coriolis: d(hu)/dt = f hv, d(hv)/dt = -f hu (both use the old momenta)
    hu(interior,interior) = hu(interior,interior) + 0.5 * dt * f * hv_old(interior,interior);
    hv(interior,interior) = hv(interior,interior) - 0.5 * dt * f * hu_old(interior,interior);

    % Apply boundary conditions
    h  = ApplyBoundaryConditions(h, n_res, halo);
    hu = ApplyBoundaryConditions(hu, n_res, halo);
    hv = ApplyBoundaryConditions(hv, n_res, halo);

    % Half-step momenta, needed for the corrector's Coriolis term
    hu_half = hu;
    hv_half = hv;

    % Fluxes at the half-step state
    XFlux = CalculateXFluxes(h, hu, hv, n_res, halo);
    YFlux = CalculateYFluxes(h, hu, hv, n_res, halo);

    % Corrector: advance a full step from the old state using half-step
    % fluxes and half-step Coriolis source (midpoint rule for both)
    h = h_old;
    hu = hu_old;
    hv = hv_old;

    [h,hu,hv] = UpdateXFluxes(h, hu, hv, XFlux, n_res, halo, dx, dt);
    [h,hu,hv] = UpdateYFluxes(h, hu, hv, YFlux, n_res, halo, dy, dt);

    hu(interior,interior) = hu(interior,interior) + dt * f * hv_half(interior,interior);
    hv(interior,interior) = hv(interior,interior) - dt * f * hu_half(interior,interior);

    % Update and store current time
    t_current = t_current + dt;
    T(t) = t_current;

    % Store [h, u, v] for the physical cells, transposed to (y, x)
    h_in = h(interior,interior);
    M(:,:,1,t) = h_in';
    M(:,:,2,t) = (hu(interior,interior) ./ h_in)';
    M(:,:,3,t) = (hv(interior,interior) ./ h_in)';
end

end

% Apply boundary conditions to variable x
%
% Fills the halo cells of the padded square array x periodically. The low
% halo (indices 1..halo) receives copies of the last `halo` physical
% rows/columns, and the high halo receives copies of the first `halo`
% physical rows/columns. Rows are wrapped before columns, so the corner halo
% cells also end up as periodic copies.
function x = ApplyBoundaryConditions(x, n_res, halo)
    low_ghost  = 1:halo;
    high_ghost = n_res+halo+1:n_res+2*halo;
    first_real = halo+1:2*halo;
    last_real  = n_res+1:n_res+halo;

    % Wrap along the first dimension
    x(low_ghost,:)  = x(last_real,:);
    x(high_ghost,:) = x(first_real,:);

    % Wrap along the second dimension (uses the rows filled above)
    x(:,low_ghost)  = x(:,last_real);
    x(:,high_ghost) = x(:,first_real);
end

% Calculate fluxes in the X direction
%
% Returns F, an (n_res+1) x n_res x 3 array holding the numerical flux of
% [h, hu, hv] through every edge of constant X. Edge i lies between padded
% rows halo+i-1 and halo+i. Each side's edge value comes from an unlimited
% centered-slope reconstruction that reads rows halo+i-2 .. halo+i+1, so halo
% must be at least 2. The two sides are combined with a local Lax-Friedrichs
% (Rusanov) flux using the wave-speed estimate |mean u| + sqrt(g * mean h).
% Reads the global `gravity`.
function F = CalculateXFluxes(h, hu, hv, n_res, halo)

    global gravity

    rows = (1:n_res+1) + halo;   % padded row immediately right of each edge
    cols = (1:n_res) + halo;     % physical columns

    % Edge values extrapolated from the cell on each side of the edge
    from_left  = @(q) 0.25 * q(rows,cols) + q(rows-1,cols) - 0.25 * q(rows-2,cols);
    from_right = @(q) 0.25 * q(rows-1,cols) + q(rows,cols) - 0.25 * q(rows+1,cols);

    hL = from_left(h);   huL = from_left(hu);   hvL = from_left(hv);
    hR = from_right(h);  huR = from_right(hu);  hvR = from_right(hv);

    % Physical x-fluxes of each reconstructed state
    fhL  = huL;
    fhuL = huL .* huL ./ hL + 0.5 * gravity * hL .* hL;
    fhvL = huL .* hvL ./ hL;

    fhR  = huR;
    fhuR = huR .* huR ./ hR + 0.5 * gravity * hR .* hR;
    fhvR = huR .* hvR ./ hR;

    % Wave-speed estimate that scales the dissipation term
    a = abs(0.5 * (huL ./ hL + huR ./ hR)) + sqrt(gravity * 0.5 * (hL + hR));

    rusanov = @(fl, fr, ql, qr) 0.5 * (fl + fr) - 0.5 * a .* (qr - ql);

    F = zeros(n_res+1, n_res, 3);
    F(:,:,1) = rusanov(fhL,  fhR,  hL,  hR);
    F(:,:,2) = rusanov(fhuL, fhuR, huL, huR);
    F(:,:,3) = rusanov(fhvL, fhvR, hvL, hvR);
end

% Update X fluxes
%
% Applies the x-direction flux differences over a step of length dt. For each
% edge i (between padded rows halo+i-1 and halo+i) and physical column j,
% dt/dx * XFlux(i,j,:) is added to the cell on the right of the edge and
% subtracted from the cell on its left. The halo cells next to the outermost
% edges (rows halo and n_res+halo+1) are modified as well; in this file they
% are overwritten by ApplyBoundaryConditions before being read again.
function [h,hu,hv] = UpdateXFluxes(h, hu, hv, XFlux, n_res, halo, dx, dt)
    ratio = dt / dx;
    gain_rows = (1:n_res+1) + halo;   % cell right of each edge
    loss_rows = gain_rows - 1;        % cell left of each edge
    cols = (1:n_res) + halo;

    Fh  = XFlux(1:n_res+1, 1:n_res, 1);
    Fhu = XFlux(1:n_res+1, 1:n_res, 2);
    Fhv = XFlux(1:n_res+1, 1:n_res, 3);

    % Each cell first gains through its left edge, then loses through its
    % right edge -- the same accumulation order as an edge-by-edge sweep.
    h(gain_rows,cols)  = h(gain_rows,cols)  + ratio * Fh;
    h(loss_rows,cols)  = h(loss_rows,cols)  - ratio * Fh;

    hu(gain_rows,cols) = hu(gain_rows,cols) + ratio * Fhu;
    hu(loss_rows,cols) = hu(loss_rows,cols) - ratio * Fhu;

    hv(gain_rows,cols) = hv(gain_rows,cols) + ratio * Fhv;
    hv(loss_rows,cols) = hv(loss_rows,cols) - ratio * Fhv;
end

% Calculate fluxes in the Y direction
%
% Returns F, an n_res x (n_res+1) x 3 array holding the numerical flux of
% [h, hu, hv] through every edge of constant Y. Edge j lies between padded
% columns halo+j-1 and halo+j. Each side's edge value comes from an unlimited
% centered-slope reconstruction that reads columns halo+j-2 .. halo+j+1, so
% halo must be at least 2. The two sides are combined with a local
% Lax-Friedrichs (Rusanov) flux using |mean v| + sqrt(g * mean h).
% "L" denotes the side at lower column index, "R" the side at higher index.
% Reads the global `gravity`.
function F = CalculateYFluxes(h, hu, hv, n_res, halo)

    global gravity

    rows = (1:n_res) + halo;     % physical rows
    cols = (1:n_res+1) + halo;   % padded column immediately above each edge

    % Edge values extrapolated from the cell on each side of the edge
    from_below = @(q) 0.25 * q(rows,cols) + q(rows,cols-1) - 0.25 * q(rows,cols-2);
    from_above = @(q) 0.25 * q(rows,cols-1) + q(rows,cols) - 0.25 * q(rows,cols+1);

    hL = from_below(h);  huL = from_below(hu);  hvL = from_below(hv);
    hR = from_above(h);  huR = from_above(hu);  hvR = from_above(hv);

    % Physical y-fluxes of each reconstructed state
    fhL  = hvL;
    fhuL = huL .* hvL ./ hL;
    fhvL = hvL .* hvL ./ hL + 0.5 * gravity * hL .* hL;

    fhR  = hvR;
    fhuR = huR .* hvR ./ hR;
    fhvR = hvR .* hvR ./ hR + 0.5 * gravity * hR .* hR;

    % Wave-speed estimate that scales the dissipation term
    a = abs(0.5 * (hvL ./ hL + hvR ./ hR)) + sqrt(gravity * 0.5 * (hL + hR));

    rusanov = @(fl, fr, ql, qr) 0.5 * (fl + fr) - 0.5 * a .* (qr - ql);

    F = zeros(n_res, n_res+1, 3);
    F(:,:,1) = rusanov(fhL,  fhR,  hL,  hR);
    F(:,:,2) = rusanov(fhuL, fhuR, huL, huR);
    F(:,:,3) = rusanov(fhvL, fhvR, hvL, hvR);
end

% Update Y fluxes
%
% Applies the y-direction flux differences over a step of length dt. For each
% physical row i and edge j (between padded columns halo+j-1 and halo+j),
% dt/dy * YFlux(i,j,:) is added to the cell above the edge (higher column
% index) and subtracted from the cell below it. The halo cells next to the
% outermost edges (columns halo and n_res+halo+1) are modified as well.
% The 7th argument is the cell spacing divided into dt; it should be the
% y spacing.
function [h,hu,hv] = UpdateYFluxes(h, hu, hv, YFlux, n_res, halo, dy, dt)
    ratio = dt / dy;
    rows = (1:n_res) + halo;
    gain_cols = (1:n_res+1) + halo;   % cell above each edge
    loss_cols = gain_cols - 1;        % cell below each edge

    Fh  = YFlux(1:n_res, 1:n_res+1, 1);
    Fhu = YFlux(1:n_res, 1:n_res+1, 2);
    Fhv = YFlux(1:n_res, 1:n_res+1, 3);

    % Each cell first gains through its lower edge, then loses through its
    % upper edge -- the same accumulation order as an edge-by-edge sweep.
    h(rows,gain_cols)  = h(rows,gain_cols)  + ratio * Fh;
    h(rows,loss_cols)  = h(rows,loss_cols)  - ratio * Fh;

    hu(rows,gain_cols) = hu(rows,gain_cols) + ratio * Fhu;
    hu(rows,loss_cols) = hu(rows,loss_cols) - ratio * Fhu;

    hv(rows,gain_cols) = hv(rows,gain_cols) + ratio * Fhv;
    hv(rows,loss_cols) = hv(rows,loss_cols) - ratio * Fhv;
end